mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00

commit b5874a587a
chore: resolve merge conflict by removing legacy all-in-one Docker files

Keeps the deletion of Dockerfile.allinone, docker-compose.yml (root), and scripts/docker/entrypoint-allinone.sh from fix/docker. Ports the Daytona sandbox env vars added by upstream/dev into docker/docker-compose.yml and docker/docker-compose.dev.yml instead.

Made-with: Cursor

126 changed files with 17384 additions and 9088 deletions
112  .cursor/skills/tdd/SKILL.md  Normal file
@@ -0,0 +1,112 @@
---
name: tdd
description: Strict Python TDD workflow using pytest (Red-Green-Refactor). Use when the user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development.
---

# Test-Driven Development

## Philosophy

**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't.

**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure.

**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior.

See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines.

## Anti-Pattern: Horizontal Slices

**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code."

This produces **crap tests**:

- Tests written in bulk test _imagined_ behavior, not _actual_ behavior
- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior
- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine
- You outrun your headlights, committing to test structure before understanding the implementation

**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it.

```
WRONG (horizontal):
RED: test1, test2, test3, test4, test5
GREEN: impl1, impl2, impl3, impl4, impl5

RIGHT (vertical):
RED→GREEN: test1→impl1
RED→GREEN: test2→impl2
RED→GREEN: test3→impl3
...
```

## Workflow

### 1. Planning

Before writing any code:

- [ ] Confirm with user what interface changes are needed
- [ ] Confirm with user which behaviors to test (prioritize)
- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation)
- [ ] Design interfaces for [testability](interface-design.md)
- [ ] List the behaviors to test (not implementation steps)
- [ ] Get user approval on the plan

Ask: "What should the public interface look like? Which behaviors are most important to test?"

**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case.

### 2. Tracer Bullet

Write ONE test that confirms ONE thing about the system:

```
RED: Write test for first behavior → test fails
GREEN: Write minimal code to pass → test passes
```

This is your tracer bullet - proves the path works end-to-end.
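To make the cycle concrete, here is a minimal pytest sketch of a single RED→GREEN pass. The `Cart`/`checkout` names are illustrative stand-ins, not code from this repository; in practice the test lives in its own file and is written before the implementation exists.

```python
# A single RED→GREEN cycle, condensed into one file for illustration.
from dataclasses import dataclass, field


@dataclass
class Product:
    name: str
    price: float


@dataclass
class Cart:
    items: list = field(default_factory=list)

    def add(self, product: Product) -> None:
        self.items.append(product)


@dataclass
class CheckoutResult:
    status: str


def checkout(cart: Cart, payment_method: str) -> CheckoutResult:
    # GREEN: just enough code to make the first test pass.
    # Later cycles add validation, payment handling, etc.
    return CheckoutResult(status="confirmed")


# RED: this test is written first and fails until checkout() above exists.
def test_user_can_checkout_with_valid_cart():
    cart = Cart()
    cart.add(Product(name="book", price=10.0))
    result = checkout(cart, payment_method="card")
    assert result.status == "confirmed"
```

Run `pytest` after each half of the cycle: the test fails while `checkout` is missing (RED), then passes once the minimal implementation exists (GREEN).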

### 3. Incremental Loop

For each remaining behavior:

```
RED: Write next test → fails
GREEN: Minimal code to pass → passes
```

Rules:

- One test at a time
- Only enough code to pass current test
- Don't anticipate future tests
- Keep tests focused on observable behavior

### 4. Refactor

After all tests pass, look for [refactor candidates](refactoring.md):

- [ ] Extract duplication
- [ ] Deepen modules (move complexity behind simple interfaces)
- [ ] Apply SOLID principles where natural
- [ ] Consider what new code reveals about existing code
- [ ] Run tests after each refactor step

**Never refactor while RED.** Get to GREEN first.

## Checklist Per Cycle

```
[ ] Test describes behavior, not implementation
[ ] Test uses public interface only
[ ] Test would survive internal refactor
[ ] Code is minimal for this test
[ ] No speculative features added
```
33  .cursor/skills/tdd/deep-modules.md  Normal file
@@ -0,0 +1,33 @@

# Deep Modules

From "A Philosophy of Software Design":

**Deep module** = small interface + lots of implementation

```
┌─────────────────────┐
│   Small Interface   │ ← Few methods, simple params
├─────────────────────┤
│                     │
│                     │
│ Deep Implementation │ ← Complex logic hidden
│                     │
│                     │
└─────────────────────┘
```

**Shallow module** = large interface + little implementation (avoid)

```
┌─────────────────────────────────┐
│         Large Interface         │ ← Many methods, complex params
├─────────────────────────────────┤
│       Thin Implementation       │ ← Just passes through
└─────────────────────────────────┘
```

When designing interfaces, ask:

- Can I reduce the number of methods?
- Can I simplify the parameters?
- Can I hide more complexity inside?
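A minimal Python sketch of the same contrast; the `FileStore` names and methods are hypothetical, chosen only to illustrate the shape of the interface:

```python
# Deep module: one simple method hides path rules, retries, and IO details.
class FileStore:
    def save(self, name: str, data: bytes) -> None:
        path = self._resolve(name)           # path conventions hidden inside
        self._write_with_retry(path, data)   # retry/IO handling hidden inside

    def _resolve(self, name: str) -> str:
        return f"/var/files/{name}"

    def _write_with_retry(self, path: str, data: bytes, attempts: int = 3) -> None:
        for _ in range(attempts):
            try:
                with open(path, "wb") as f:
                    f.write(data)
                return
            except OSError:
                continue


# Shallow module: the caller has to orchestrate every step itself.
class ShallowFileStore:
    def resolve_path(self, name: str) -> str: ...
    def open_handle(self, path: str): ...
    def write_chunk(self, handle, data: bytes) -> None: ...
    def close_handle(self, handle) -> None: ...
```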
33  .cursor/skills/tdd/interface-design.md  Normal file
@@ -0,0 +1,33 @@

# Interface Design for Testability

Good interfaces make testing natural:

1. **Accept dependencies, don't create them**

   ```python
   # Testable
   def process_order(order, payment_gateway):
       pass

   # Hard to test
   def process_order(order):
       gateway = StripeGateway()
   ```

2. **Return results, don't produce side effects**

   ```python
   # Testable
   def calculate_discount(cart) -> float:
       return discount

   # Hard to test
   def apply_discount(cart) -> None:
       cart.total -= discount
   ```

3. **Small surface area**
   * Fewer methods = fewer tests needed
   * Fewer params = simpler test setup
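A short sketch of how the first two properties pay off in a test; the `FakeGateway` and the discount rule are hypothetical examples, not project code:

```python
class FakeGateway:
    """Stands in for a real payment gateway at the system boundary."""

    def __init__(self):
        self.charges = []

    def charge(self, amount):
        self.charges.append(amount)
        return {"status": "ok"}


def process_order(order, payment_gateway):
    # Dependency is injected, so the test controls it.
    return payment_gateway.charge(order["total"])


def calculate_discount(cart) -> float:
    # Pure function: returns a result instead of mutating the cart.
    return cart["total"] * 0.1 if cart["total"] > 100 else 0.0


def test_process_order_charges_total():
    gateway = FakeGateway()
    result = process_order({"total": 42.0}, gateway)
    assert result["status"] == "ok"
    assert gateway.charges == [42.0]


def test_discount_applies_over_threshold():
    assert calculate_discount({"total": 200.0}) == 20.0
```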
69  .cursor/skills/tdd/mocking.md  Normal file
@@ -0,0 +1,69 @@

# When to Mock

Mock at **system boundaries** only:

* External APIs (payment, email, etc.)
* Databases (sometimes - prefer test DB)
* Time/randomness
* File system (sometimes)

Don't mock:

* Your own classes/modules
* Internal collaborators
* Anything you control

## Designing for Mockability

At system boundaries, design interfaces that are easy to mock:

**1. Use dependency injection**

Pass external dependencies in rather than creating them internally:

```python
import os

# Easy to mock
def process_payment(order, payment_client):
    return payment_client.charge(order.total)

# Hard to mock
def process_payment(order):
    client = StripeClient(os.getenv("STRIPE_KEY"))
    return client.charge(order.total)
```

**2. Prefer SDK-style interfaces over generic fetchers**

Create specific functions for each external operation instead of one generic function with conditional logic:

```python
import requests

# GOOD: Each function is independently mockable
class UserAPI:
    def get_user(self, user_id):
        return requests.get(f"/users/{user_id}")

    def get_orders(self, user_id):
        return requests.get(f"/users/{user_id}/orders")

    def create_order(self, data):
        return requests.post("/orders", json=data)

# BAD: Mocking requires conditional logic inside the mock
class GenericAPI:
    def fetch(self, endpoint, method="GET", data=None):
        return requests.request(method, endpoint, json=data)
```

The SDK approach means:

* Each mock returns one specific shape
* No conditional logic in test setup
* Easier to see which endpoints a test exercises
* Type safety per endpoint
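For illustration, a hedged sketch of what SDK-style mocking looks like in a test; `build_order_summary` is a hypothetical function under test, not part of this repository:

```python
from unittest.mock import Mock


def build_order_summary(api, user_id):
    user = api.get_user(user_id)
    orders = api.get_orders(user_id)
    total = sum(o["total"] for o in orders)
    return f"{user['name']} has {len(orders)} order(s) totalling {total}"


def test_order_summary_uses_user_and_orders():
    # Each SDK-style method stubs exactly one response shape — no routing logic.
    api = Mock(spec=["get_user", "get_orders"])
    api.get_user.return_value = {"id": 1, "name": "Alice"}
    api.get_orders.return_value = [{"id": 10, "total": 42.0}]

    assert build_order_summary(api, user_id=1) == "Alice has 1 order(s) totalling 42.0"
```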
10  .cursor/skills/tdd/refactoring.md  Normal file
@@ -0,0 +1,10 @@

# Refactor Candidates

After TDD cycle, look for:

- **Duplication** → Extract function/class
- **Long methods** → Break into private helpers (keep tests on public interface)
- **Shallow modules** → Combine or deepen
- **Feature envy** → Move logic to where data lives
- **Primitive obsession** → Introduce value objects
- **Existing code** the new code reveals as problematic
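As one illustration of the primitive-obsession item, a generic value-object sketch (not code from this repository):

```python
from dataclasses import dataclass


# Before: money passed around as a bare float, currency tracked by convention.
def add_prices(a: float, b: float) -> float:
    return a + b


# After: a small value object makes the unit explicit and centralises the rules.
@dataclass(frozen=True)
class Money:
    amount: float
    currency: str = "USD"

    def __add__(self, other: "Money") -> "Money":
        if self.currency != other.currency:
            raise ValueError("Cannot add different currencies")
        return Money(self.amount + other.amount, self.currency)
```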
60  .cursor/skills/tdd/tests.md  Normal file
@@ -0,0 +1,60 @@

# Good and Bad Tests

## Good Tests

**Integration-style**: Test through real interfaces, not mocks of internal parts.

```python
# GOOD: Tests observable behavior
def test_user_can_checkout_with_valid_cart():
    cart = create_cart()
    cart.add(product)
    result = checkout(cart, payment_method)
    assert result.status == "confirmed"
```

Characteristics:

* Tests behavior users/callers care about
* Uses public API only
* Survives internal refactors
* Describes WHAT, not HOW
* One logical assertion per test

## Bad Tests

**Implementation-detail tests**: Coupled to internal structure.

```python
# BAD: Tests implementation details
def test_checkout_calls_payment_service_process():
    mock_payment = MagicMock()
    checkout(cart, mock_payment)
    mock_payment.process.assert_called_with(cart.total)
```

Red flags:

* Mocking internal collaborators
* Testing private methods
* Asserting on call counts/order
* Test breaks when refactoring without behavior change
* Test name describes HOW not WHAT
* Verifying through external means instead of interface

```python
# BAD: Bypasses interface to verify
def test_create_user_saves_to_database():
    create_user({"name": "Alice"})
    row = db.query("SELECT * FROM users WHERE name = ?", ["Alice"])
    assert row is not None

# GOOD: Verifies through interface
def test_create_user_makes_user_retrievable():
    user = create_user({"name": "Alice"})
    retrieved = get_user(user.id)
    assert retrieved.name == "Alice"
```
@@ -167,37 +167,21 @@ LANGSMITH_ENDPOINT=https://api.smith.langchain.com
LANGSMITH_API_KEY=lsv2_pt_.....
LANGSMITH_PROJECT=surfsense

# Uvicorn Server Configuration
# Full documentation for Uvicorn options can be found at: https://www.uvicorn.org/#command-line-options
UVICORN_HOST="0.0.0.0"
UVICORN_PORT=8000
UVICORN_LOG_LEVEL=info

# Agent Specific Configuration
# Daytona Sandbox (secure cloud code execution for deep agent)
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
DAYTONA_SANDBOX_ENABLED=TRUE
DAYTONA_API_KEY=dtn_asdasfasfafas
DAYTONA_API_URL=https://app.daytona.io/api
DAYTONA_TARGET=us
# Directory for locally-persisted sandbox files (after sandbox deletion)
SANDBOX_FILES_DIR=sandbox_files

# OPTIONAL: Advanced Uvicorn Options (uncomment to use)
# UVICORN_PROXY_HEADERS=false
# UVICORN_FORWARDED_ALLOW_IPS="127.0.0.1"
# UVICORN_WORKERS=1
# UVICORN_ACCESS_LOG=true
# UVICORN_LOOP="auto"
# UVICORN_HTTP="auto"
# UVICORN_WS="auto"
# UVICORN_LIFESPAN="auto"
# UVICORN_LOG_CONFIG=""
# UVICORN_SERVER_HEADER=true
# UVICORN_DATE_HEADER=true
# UVICORN_LIMIT_CONCURRENCY=
# UVICORN_LIMIT_MAX_REQUESTS=
# UVICORN_TIMEOUT_KEEP_ALIVE=5
# UVICORN_TIMEOUT_NOTIFY=30
# UVICORN_SSL_KEYFILE=""
# UVICORN_SSL_CERTFILE=""
# UVICORN_SSL_KEYFILE_PASSWORD=""
# UVICORN_SSL_VERSION=""
# UVICORN_SSL_CERT_REQS=""
# UVICORN_SSL_CA_CERTS=""
# UVICORN_SSL_CIPHERS=""
# UVICORN_HEADERS=""
# UVICORN_USE_COLORS=true
# UVICORN_UDS=""
# UVICORN_FD=""
# UVICORN_ROOT_PATH=""

# ============================================================
# Testing (optional — all have sensible defaults)
# ============================================================
# TEST_BACKEND_URL=http://localhost:8000
# TEST_DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
# TEST_USER_EMAIL=testuser@surfsense.com
# TEST_USER_PASSWORD=testpassword123
1  surfsense_backend/.gitignore  vendored
@@ -6,6 +6,7 @@ __pycache__/
.flashrank_cache
surf_new_backend.egg-info/
podcasts/
sandbox_files/
temp_audio/
celerybeat-schedule*
celerybeat-schedule.*
@@ -10,6 +10,7 @@ from collections.abc import Sequence
from typing import Any

from deepagents import create_deep_agent
from deepagents.backends.protocol import SandboxBackendProtocol
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer

@@ -128,6 +129,7 @@ async def create_surfsense_deep_agent(
    additional_tools: Sequence[BaseTool] | None = None,
    firecrawl_api_key: str | None = None,
    thread_visibility: ChatVisibility | None = None,
    sandbox_backend: SandboxBackendProtocol | None = None,
):
    """
    Create a SurfSense deep agent with configurable tools and prompts.

@@ -167,6 +169,9 @@ async def create_surfsense_deep_agent(
            These are always added regardless of enabled/disabled settings.
        firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
            Falls back to Chromium/Trafilatura if not provided.
        sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
            secure code execution. When provided, the agent gets an
            isolated ``execute`` tool for running shell commands.

    Returns:
        CompiledStateGraph: The configured deep agent

@@ -277,19 +282,26 @@ async def create_surfsense_deep_agent(
    )

    # Build system prompt based on agent_config
    _sandbox_enabled = sandbox_backend is not None
    if agent_config is not None:
        # Use configurable prompt with settings from NewLLMConfig
        system_prompt = build_configurable_system_prompt(
            custom_system_instructions=agent_config.system_instructions,
            use_default_system_instructions=agent_config.use_default_system_instructions,
            citations_enabled=agent_config.citations_enabled,
            thread_visibility=thread_visibility,
            sandbox_enabled=_sandbox_enabled,
        )
    else:
        system_prompt = build_surfsense_system_prompt(
            thread_visibility=thread_visibility,
            sandbox_enabled=_sandbox_enabled,
        )

    # Build optional kwargs for the deep agent
    deep_agent_kwargs: dict[str, Any] = {}
    if sandbox_backend is not None:
        deep_agent_kwargs["backend"] = sandbox_backend

    # Create the deep agent with system prompt and checkpointer
    # Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
    agent = create_deep_agent(

@@ -298,6 +310,7 @@ async def create_surfsense_deep_agent(
        system_prompt=system_prompt,
        context_schema=SurfSenseContextSchema,
        checkpointer=checkpointer,
        **deep_agent_kwargs,
    )

    return agent
266  surfsense_backend/app/agents/new_chat/sandbox.py  Normal file
@@ -0,0 +1,266 @@
"""
Daytona sandbox provider for SurfSense deep agent.

Manages the lifecycle of sandboxed code execution environments.
Each conversation thread gets its own isolated sandbox instance
via the Daytona cloud API, identified by labels.

Files created during a session are persisted to local storage before
the sandbox is deleted so they remain downloadable after cleanup.
"""

from __future__ import annotations

import asyncio
import logging
import os
import shutil
from pathlib import Path

from daytona import (
    CreateSandboxFromSnapshotParams,
    Daytona,
    DaytonaConfig,
    SandboxState,
)
from daytona.common.errors import DaytonaError
from deepagents.backends.protocol import ExecuteResponse
from langchain_daytona import DaytonaSandbox

logger = logging.getLogger(__name__)


class _TimeoutAwareSandbox(DaytonaSandbox):
    """DaytonaSandbox subclass that accepts the per-command *timeout*
    kwarg required by the deepagents middleware.

    The upstream ``langchain-daytona`` ``execute()`` ignores timeout,
    so deepagents raises *"This sandbox backend does not support
    per-command timeout overrides"* on every first call. This thin
    wrapper forwards the parameter to the Daytona SDK.
    """

    def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
        t = timeout if timeout is not None else self._timeout
        result = self._sandbox.process.exec(command, timeout=t)
        return ExecuteResponse(
            output=result.result,
            exit_code=result.exit_code,
            truncated=False,
        )

    async def aexecute(
        self, command: str, *, timeout: int | None = None
    ) -> ExecuteResponse:  # type: ignore[override]
        return await asyncio.to_thread(self.execute, command, timeout=timeout)


_daytona_client: Daytona | None = None
THREAD_LABEL_KEY = "surfsense_thread"


def is_sandbox_enabled() -> bool:
    return os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"


def _get_client() -> Daytona:
    global _daytona_client
    if _daytona_client is None:
        config = DaytonaConfig(
            api_key=os.environ.get("DAYTONA_API_KEY", ""),
            api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
            target=os.environ.get("DAYTONA_TARGET", "us"),
        )
        _daytona_client = Daytona(config)
    return _daytona_client


def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
    """Find an existing sandbox for *thread_id*, or create a new one.

    If an existing sandbox is found but is stopped/archived, it will be
    restarted automatically before returning.
    """
    client = _get_client()
    labels = {THREAD_LABEL_KEY: thread_id}

    try:
        sandbox = client.find_one(labels=labels)
        logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)

        if sandbox.state in (
            SandboxState.STOPPED,
            SandboxState.STOPPING,
            SandboxState.ARCHIVED,
        ):
            logger.info("Starting stopped sandbox %s …", sandbox.id)
            sandbox.start(timeout=60)
            logger.info("Sandbox %s is now started", sandbox.id)
        elif sandbox.state in (
            SandboxState.ERROR,
            SandboxState.BUILD_FAILED,
            SandboxState.DESTROYED,
        ):
            logger.warning(
                "Sandbox %s in unrecoverable state %s — creating a new one",
                sandbox.id,
                sandbox.state,
            )
            sandbox = client.create(
                CreateSandboxFromSnapshotParams(language="python", labels=labels)
            )
            logger.info("Created replacement sandbox: %s", sandbox.id)
        elif sandbox.state != SandboxState.STARTED:
            sandbox.wait_for_sandbox_start(timeout=60)

    except Exception:
        logger.info("No existing sandbox for thread %s — creating one", thread_id)
        sandbox = client.create(
            CreateSandboxFromSnapshotParams(language="python", labels=labels)
        )
        logger.info("Created new sandbox: %s", sandbox.id)

    return _TimeoutAwareSandbox(sandbox=sandbox)


async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
    """Get or create a sandbox for a conversation thread.

    Uses the thread_id as a label so the same sandbox persists
    across multiple messages within the same conversation.

    Args:
        thread_id: The conversation thread identifier.

    Returns:
        DaytonaSandbox connected to the sandbox.
    """
    return await asyncio.to_thread(_find_or_create, str(thread_id))


async def delete_sandbox(thread_id: int | str) -> None:
    """Delete the sandbox for a conversation thread."""

    def _delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}
        try:
            sandbox = client.find_one(labels=labels)
        except DaytonaError:
            logger.debug(
                "No sandbox to delete for thread %s (already removed)", thread_id
            )
            return
        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox for thread %s",
                thread_id,
                exc_info=True,
            )

    await asyncio.to_thread(_delete)


# ---------------------------------------------------------------------------
# Local file persistence
# ---------------------------------------------------------------------------


def _get_sandbox_files_dir() -> Path:
    return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))


def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
    """Map a sandbox-internal absolute path to a local filesystem path."""
    relative = sandbox_path.lstrip("/")
    return _get_sandbox_files_dir() / str(thread_id) / relative


def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
    """Read a previously-persisted sandbox file from local storage.

    Returns the file bytes, or *None* if the file does not exist locally.
    """
    local = _local_path_for(thread_id, sandbox_path)
    if local.is_file():
        return local.read_bytes()
    return None


def delete_local_sandbox_files(thread_id: int | str) -> None:
    """Remove all locally-persisted sandbox files for a thread."""
    thread_dir = _get_sandbox_files_dir() / str(thread_id)
    if thread_dir.is_dir():
        shutil.rmtree(thread_dir, ignore_errors=True)
        logger.info("Deleted local sandbox files for thread %s", thread_id)


async def persist_and_delete_sandbox(
    thread_id: int | str,
    sandbox_file_paths: list[str],
) -> None:
    """Download sandbox files to local storage, then delete the sandbox.

    Each file in *sandbox_file_paths* is downloaded from the Daytona
    sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/…``.
    Per-file errors are logged but do **not** prevent the sandbox from
    being deleted — freeing Daytona storage is the priority.
    """

    def _persist_and_delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}

        try:
            sandbox = client.find_one(labels=labels)
        except Exception:
            logger.info(
                "No sandbox found for thread %s — nothing to persist", thread_id
            )
            return

        # Ensure the sandbox is running so we can download files
        if sandbox.state != SandboxState.STARTED:
            try:
                sandbox.start(timeout=60)
            except Exception:
                logger.warning(
                    "Could not start sandbox %s for file download — deleting anyway",
                    sandbox.id,
                    exc_info=True,
                )
                try:
                    client.delete(sandbox)
                except Exception:
                    pass
                return

        for path in sandbox_file_paths:
            try:
                content: bytes = sandbox.fs.download_file(path)
                local = _local_path_for(thread_id, path)
                local.parent.mkdir(parents=True, exist_ok=True)
                local.write_bytes(content)
                logger.info("Persisted sandbox file %s → %s", path, local)
            except Exception:
                logger.warning(
                    "Failed to persist sandbox file %s for thread %s",
                    path,
                    thread_id,
                    exc_info=True,
                )

        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox %s after persistence",
                sandbox.id,
                exc_info=True,
            )

    await asyncio.to_thread(_persist_and_delete)
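Taken together with the `sandbox_backend` parameter added to `create_surfsense_deep_agent` above, the wiring is roughly as follows. This is a hedged sketch based only on the signatures visible in this diff; the actual call site is not part of the excerpt, imports are omitted, and `build_agent_for_thread` is a hypothetical helper name.

```python
async def build_agent_for_thread(thread_id, thread_visibility, **agent_kwargs):
    # Hypothetical helper, not part of this commit.
    sandbox_backend = None
    if is_sandbox_enabled():
        # Reuses the thread's existing Daytona sandbox if one exists, else creates it.
        sandbox_backend = await get_or_create_sandbox(thread_id)

    return await create_surfsense_deep_agent(
        thread_visibility=thread_visibility,
        sandbox_backend=sandbox_backend,  # enables the `execute` tool and prompt section
        **agent_kwargs,
    )
```

When `DAYTONA_SANDBOX_ENABLED` is not set to `TRUE`, `sandbox_backend` stays `None` and the agent is built without the `execute` tool or the sandbox prompt section.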
@@ -645,6 +645,87 @@ However, from your video learning, it's important to note that asyncio is not su
</citation_instructions>
"""

# Sandbox / code execution instructions — appended when sandbox backend is enabled.
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
SANDBOX_EXECUTION_INSTRUCTIONS = """
<code_execution>
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).

## CRITICAL — CODE-FIRST RULE

ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
- **Creating a chart, plot, graph, or visualization** → Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
- **Data analysis, statistics, or computation** → Write code to compute the answer. Do not do math by hand in text.
- **Generating or transforming files** (CSV, PDF, images, etc.) → Write code to create the file.
- **Running, testing, or debugging code** → Execute it in the sandbox.

This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.

Example (CORRECT):
User: "Create a pie chart of my benefits"
→ 1. search_knowledge_base → retrieve benefits data
→ 2. Immediately execute Python code (matplotlib) to generate the pie chart
→ 3. Return the downloadable file + brief description

Example (WRONG):
User: "Create a pie chart of my benefits"
→ 1. search_knowledge_base → retrieve benefits data
→ 2. Print a text table with percentages and ask the user if they want a chart ← NEVER do this

## When to Use Code Execution

Use the sandbox when the task benefits from actually running code rather than just describing it:
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
- **Calculations**: Math, financial modeling, unit conversions, simulations
- **Code validation**: Run and test code snippets the user provides or asks about
- **File processing**: Parse, transform, or convert data files
- **Quick prototyping**: Demonstrate working code for the user's problem
- **Package exploration**: Install and test libraries the user is evaluating

## When NOT to Use Code Execution

Do not use the sandbox for:
- Answering factual questions from your own knowledge
- Summarizing or explaining concepts
- Simple formatting or text generation tasks
- Tasks that don't require running code to answer

## Package Management

- Use `pip install <package>` to install Python packages as needed
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
- Always verify a package installed successfully before using it

## Working Guidelines

- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
- **Iterative approach**: For complex tasks, break work into steps — write code, run it, check output, refine
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
- **Be efficient**: Install packages once per session. Combine related commands when possible.
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.

## Sharing Generated Files

When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
- **DO NOT call `display_image`** for files created inside the sandbox. Sandbox files are not accessible via public URLs, so `display_image` will always show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
- Always describe what the file contains in your response text so the user knows what they are downloading.
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.

## Data Analytics Best Practices

When the user asks you to analyze data:
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
2. Clean and validate before computing (handle nulls, check types)
3. Perform the analysis and present results clearly
4. Offer follow-up insights or visualizations when appropriate
</code_execution>
"""

# Anti-citation prompt - used when citations are disabled
# This explicitly tells the model NOT to include citations
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
@@ -670,6 +751,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
def build_surfsense_system_prompt(
    today: datetime | None = None,
    thread_visibility: ChatVisibility | None = None,
    sandbox_enabled: bool = False,
) -> str:
    """
    Build the SurfSense system prompt with default settings.

@@ -678,10 +760,12 @@ def build_surfsense_system_prompt(
    - Default system instructions
    - Tools instructions (always included)
    - Citation instructions enabled
    - Sandbox execution instructions (when sandbox_enabled=True)

    Args:
        today: Optional datetime for today's date (defaults to current UTC date)
        thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
        sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).

    Returns:
        Complete system prompt string

@@ -691,7 +775,13 @@ def build_surfsense_system_prompt(
    system_instructions = _get_system_instructions(visibility, today)
    tools_instructions = _get_tools_instructions(visibility)
    citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
    return system_instructions + tools_instructions + citation_instructions
    sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
    return (
        system_instructions
        + tools_instructions
        + citation_instructions
        + sandbox_instructions
    )


def build_configurable_system_prompt(

@@ -700,14 +790,16 @@ def build_configurable_system_prompt(
    citations_enabled: bool = True,
    today: datetime | None = None,
    thread_visibility: ChatVisibility | None = None,
    sandbox_enabled: bool = False,
) -> str:
    """
    Build a configurable SurfSense system prompt based on NewLLMConfig settings.

    The prompt is composed of three parts:
    The prompt is composed of up to four parts:
    1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
    2. Tools Instructions - always included (SURFSENSE_TOOLS_INSTRUCTIONS)
    3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
    4. Sandbox Execution Instructions - when sandbox_enabled=True

    Args:
        custom_system_instructions: Custom system instructions to use. If empty/None and

@@ -719,6 +811,7 @@ def build_configurable_system_prompt(
            anti-citation instructions (False).
        today: Optional datetime for today's date (defaults to current UTC date)
        thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
        sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).

    Returns:
        Complete system prompt string

@@ -727,7 +820,6 @@ def build_configurable_system_prompt(

    # Determine system instructions
    if custom_system_instructions and custom_system_instructions.strip():
        # Use custom instructions, injecting the date placeholder if present
        system_instructions = custom_system_instructions.format(
            resolved_today=resolved_today
        )

@@ -735,7 +827,6 @@ def build_configurable_system_prompt(
        visibility = thread_visibility or ChatVisibility.PRIVATE
        system_instructions = _get_system_instructions(visibility, today)
    else:
        # No system instructions (edge case)
        system_instructions = ""

    # Tools instructions: conditional on thread_visibility (private vs shared memory wording)

@@ -748,7 +839,14 @@ def build_configurable_system_prompt(
        else SURFSENSE_NO_CITATION_INSTRUCTIONS
    )

    return system_instructions + tools_instructions + citation_instructions
    sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""

    return (
        system_instructions
        + tools_instructions
        + citation_instructions
        + sandbox_instructions
    )


def get_default_system_instructions() -> str:
@@ -0,0 +1,11 @@
from app.agents.new_chat.tools.google_drive.create_file import (
    create_create_google_drive_file_tool,
)
from app.agents.new_chat.tools.google_drive.trash_file import (
    create_delete_google_drive_file_tool,
)

__all__ = [
    "create_create_google_drive_file_tool",
    "create_delete_google_drive_file_tool",
]
@@ -0,0 +1,239 @@
import logging
from typing import Any, Literal

from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)

_MIME_MAP: dict[str, str] = {
    "google_doc": GOOGLE_DOC,
    "google_sheet": GOOGLE_SHEET,
}


def create_create_google_drive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_google_drive_file(
        name: str,
        file_type: Literal["google_doc", "google_sheet"],
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new Google Doc or Google Sheet in Google Drive.

        Use this tool when the user explicitly asks to create a new document
        or spreadsheet in Google Drive.

        Args:
            name: The file name (without extension).
            file_type: Either "google_doc" or "google_sheet".
            content: Optional initial content. For google_doc, provide markdown text.
                For google_sheet, provide CSV-formatted text.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - file_id: Google Drive file ID (if success)
            - name: File name (if success)
            - web_view_link: URL to open the file (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Create a Google Doc called 'Meeting Notes'"
        - "Create a spreadsheet named 'Budget 2026' with some sample data"
        """
        logger.info(
            f"create_google_drive_file called: name='{name}', type='{file_type}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",
            }

        if file_type not in _MIME_MAP:
            return {
                "status": "error",
                "message": f"Unsupported file type '{file_type}'. Use 'google_doc' or 'google_sheet'.",
            }

        try:
            metadata_service = GoogleDriveToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            logger.info(
                f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
            )
            approval = interrupt(
                {
                    "type": "google_drive_file_creation",
                    "action": {
                        "tool": "create_google_drive_file",
                        "params": {
                            "name": name,
                            "file_type": file_type,
                            "content": content,
                            "connector_id": None,
                            "parent_folder_id": None,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
                }

            final_params: dict[str, Any] = {}
            edited_action = decision.get("edited_action")
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_name = final_params.get("name", name)
            final_file_type = final_params.get("file_type", file_type)
            final_content = final_params.get("content", content)
            final_connector_id = final_params.get("connector_id")
            final_parent_folder_id = final_params.get("parent_folder_id")

            if not final_name or not final_name.strip():
                return {"status": "error", "message": "File name cannot be empty."}

            mime_type = _MIME_MAP.get(final_file_type)
            if not mime_type:
                return {
                    "status": "error",
                    "message": f"Unsupported file type '{final_file_type}'.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Drive connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
            )
            client = GoogleDriveClient(
                session=db_session, connector_id=actual_connector_id
            )
            try:
                created = await client.create_file(
                    name=final_name,
                    mime_type=mime_type,
                    parent_folder_id=final_parent_folder_id,
                    content=final_content,
                )
            except HttpError as http_err:
                if http_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
                    )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate.",
                    }
                raise

            logger.info(
                f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
            )
            return {
                "status": "success",
                "file_id": created.get("id"),
                "name": created.get("name"),
                "web_view_link": created.get("webViewLink"),
                "message": f"Successfully created '{created.get('name')}' in Google Drive.",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating Google Drive file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the file. Please try again.",
            }

    return create_google_drive_file
@@ -0,0 +1,243 @@
import logging
from typing import Any

from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.connectors.google_drive.client import GoogleDriveClient
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)


def create_delete_google_drive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def delete_google_drive_file(
        file_name: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Move a Google Drive file to trash.

        Use this tool when the user explicitly asks to delete, remove, or trash
        a file in Google Drive.

        Args:
            file_name: The exact name of the file to trash (as it appears in Drive).
            delete_from_kb: Whether to also remove the file from the knowledge base.
                Default is False.
                Set to True to remove from both Google Drive and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - file_id: Google Drive file ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the file name or check if it has been indexed.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry this tool.

        Examples:
        - "Delete the 'Meeting Notes' file from Google Drive"
        - "Trash the 'Old Budget' spreadsheet"
        """
        logger.info(
            f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleDriveToolMetadataService(db_session)
            context = await metadata_service.get_trash_context(
                search_space_id, user_id, file_name
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"File not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch trash context: {error_msg}")
                return {"status": "error", "message": error_msg}

            file = context["file"]
            file_id = file["file_id"]
            document_id = file.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not file_id:
                return {
                    "status": "error",
                    "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
                }

            logger.info(
                f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
            )
            approval = interrupt(
                {
                    "type": "google_drive_file_trash",
                    "action": {
                        "tool": "delete_google_drive_file",
                        "params": {
                            "file_id": file_id,
                            "connector_id": connector_id_from_context,
                            "delete_from_kb": delete_from_kb,
                        },
                    },
                    "context": context,
                }
            )

            decisions_raw = (
                approval.get("decisions", []) if isinstance(approval, dict) else []
            )
            decisions = (
                decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
            )
            decisions = [d for d in decisions if isinstance(d, dict)]
            if not decisions:
                logger.warning("No approval decision received")
                return {"status": "error", "message": "No approval decision received"}

            decision = decisions[0]
            decision_type = decision.get("type") or decision.get("decision_type")
            logger.info(f"User decision: {decision_type}")

            if decision_type == "reject":
                return {
                    "status": "rejected",
                    "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
                }

            edited_action = decision.get("edited_action")
            final_params: dict[str, Any] = {}
            if isinstance(edited_action, dict):
                edited_args = edited_action.get("args")
                if isinstance(edited_args, dict):
                    final_params = edited_args
            elif isinstance(decision.get("args"), dict):
                final_params = decision["args"]

            final_file_id = final_params.get("file_id", file_id)
            final_connector_id = final_params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this file.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Google Drive connector is invalid or has been disconnected.",
                }

            logger.info(
                f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
            )
            client = GoogleDriveClient(session=db_session, connector_id=connector.id)
            try:
                await client.trash_file(file_id=final_file_id)
            except HttpError as http_err:
                if http_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {http_err}"
                    )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate.",
                    }
                raise

            logger.info(
                f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
            )

            trash_result: dict[str, Any] = {
                "status": "success",
                "file_id": final_file_id,
                "message": f"Successfully moved '{file['name']}' to trash.",
            }

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    trash_result["warning"] = (
                        f"File moved to trash, but failed to remove from knowledge base: {e!s}"
                    )

            trash_result["deleted_from_kb"] = deleted_from_kb
            if deleted_from_kb:
                trash_result["message"] = (
                    f"{trash_result.get('message', '')} (also removed from knowledge base)"
                )

            return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting Google Drive file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while trashing the file. Please try again.",
            }

    return delete_google_drive_file
@ -47,6 +47,10 @@ from app.db import ChatVisibility
|
|||
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .linear import (
|
||||
create_create_linear_issue_tool,
|
||||
|
|
@@ -292,6 +296,29 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # GOOGLE DRIVE TOOLS - create files, delete files
    # =========================================================================
    ToolDefinition(
        name="create_google_drive_file",
        description="Create a new Google Doc or Google Sheet in Google Drive",
        factory=lambda deps: create_create_google_drive_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    ToolDefinition(
        name="delete_google_drive_file",
        description="Move an indexed Google Drive file to trash",
        factory=lambda deps: create_delete_google_drive_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
]
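For orientation, a minimal sketch of how these definitions are consumed: each entry's `factory` is called with a dependency dict that satisfies its `requires` list. The loop and variable names below are illustrative and are not the actual registry code in this repository.

```python
# Illustrative only: how a caller might materialize the Google Drive tools.
deps = {"db_session": db_session, "search_space_id": 1, "user_id": "user-uuid"}

tools = [
    definition.factory(deps)
    for definition in BUILTIN_TOOLS
    if all(dep in deps for dep in definition.requires)
]
```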
@@ -1,12 +1,15 @@
"""Google Drive API client."""

import io
from typing import Any

from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload
from sqlalchemy.ext.asyncio import AsyncSession

from .credentials import get_valid_credentials
from .file_types import GOOGLE_DOC, GOOGLE_SHEET


class GoogleDriveClient:
@@ -179,3 +182,65 @@ class GoogleDriveClient:
            return None, f"HTTP error exporting file: {e.resp.status}"
        except Exception as e:
            return None, f"Error exporting file: {e!s}"

    async def create_file(
        self,
        name: str,
        mime_type: str,
        parent_folder_id: str | None = None,
        content: str | None = None,
    ) -> dict[str, Any]:
        service = await self.get_service()

        body: dict[str, Any] = {"name": name, "mimeType": mime_type}
        if parent_folder_id:
            body["parents"] = [parent_folder_id]

        media: MediaIoBaseUpload | None = None
        if content:
            if mime_type == GOOGLE_DOC:
                import markdown as md_lib

                html = md_lib.markdown(content)
                media = MediaIoBaseUpload(
                    io.BytesIO(html.encode("utf-8")),
                    mimetype="text/html",
                    resumable=False,
                )
            elif mime_type == GOOGLE_SHEET:
                media = MediaIoBaseUpload(
                    io.BytesIO(content.encode("utf-8")),
                    mimetype="text/csv",
                    resumable=False,
                )

        if media:
            return (
                service.files()
                .create(
                    body=body,
                    media_body=media,
                    fields="id,name,mimeType,webViewLink",
                    supportsAllDrives=True,
                )
                .execute()
            )

        return (
            service.files()
            .create(
                body=body,
                fields="id,name,mimeType,webViewLink",
                supportsAllDrives=True,
            )
            .execute()
        )

    async def trash_file(self, file_id: str) -> bool:
        service = await self.get_service()
        service.files().update(
            fileId=file_id,
            body={"trashed": True},
            supportsAllDrives=True,
        ).execute()
        return True
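A minimal sketch of calling the two new methods. The client construction details are outside this hunk, so the constructor arguments here are assumptions; the returned fields match the `fields` string requested in `create()`.

```python
# Illustrative only; GoogleDriveClient construction is not shown in this diff.
client = GoogleDriveClient(db_session=db_session, connector_id=connector.id)

created = await client.create_file(
    name="Meeting notes",
    mime_type=GOOGLE_DOC,
    content="# Agenda\n- item one",
)
print(created["webViewLink"])  # one of the fields requested in create()

moved = await client.trash_file(created["id"])  # returns True on success
```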
0 surfsense_backend/app/indexing_pipeline/__init__.py Normal file
@@ -0,0 +1,46 @@
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService


async def index_uploaded_file(
    markdown_content: str,
    filename: str,
    etl_service: str,
    search_space_id: int,
    user_id: str,
    session: AsyncSession,
    llm,
) -> None:
    connector_doc = ConnectorDocument(
        title=filename,
        source_markdown=markdown_content,
        unique_id=filename,
        document_type=DocumentType.FILE,
        search_space_id=search_space_id,
        created_by_id=user_id,
        connector_id=None,
        should_summarize=True,
        should_use_code_chunker=False,
        fallback_summary=markdown_content[:4000],
        metadata={
            "FILE_NAME": filename,
            "ETL_SERVICE": etl_service,
        },
    )

    service = IndexingPipelineService(session)
    documents = await service.prepare_for_indexing([connector_doc])

    if not documents:
        raise RuntimeError("prepare_for_indexing returned no documents")

    indexed = await service.index(documents[0], connector_doc, llm)

    if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
        raise RuntimeError(indexed.status.get("reason", "Indexing failed"))

    indexed.content_needs_reindexing = False
    await session.commit()
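A minimal sketch of how a caller might use this entry point after an ETL step has produced markdown. The session factory comes from elsewhere in this codebase; the surrounding handler and the "docling" label are assumptions for illustration.

```python
# Illustrative only: the wrapper function and ETL label are hypothetical.
from app.db import async_session_maker


async def index_one_upload(markdown: str, filename: str, search_space_id: int, user_id: str, llm) -> None:
    async with async_session_maker() as session:
        await index_uploaded_file(
            markdown_content=markdown,
            filename=filename,
            etl_service="docling",
            search_space_id=search_space_id,
            user_id=user_id,
            session=session,
            llm=llm,
        )
```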
@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field, field_validator

from app.db import DocumentType


class ConnectorDocument(BaseModel):
    """Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""

    title: str
    source_markdown: str
    unique_id: str
    document_type: DocumentType
    search_space_id: int = Field(gt=0)
    should_summarize: bool = True
    should_use_code_chunker: bool = False
    fallback_summary: str | None = None
    metadata: dict = {}
    connector_id: int | None = None
    created_by_id: str

    @field_validator("title", "source_markdown", "unique_id", "created_by_id")
    @classmethod
    def not_empty(cls, v: str, info) -> str:
        if not v.strip():
            raise ValueError(f"{info.field_name} must not be empty or whitespace")
        return v
@@ -0,0 +1,9 @@
from app.config import config


def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
    """Chunk a text string using the configured chunker and return the chunk texts."""
    chunker = (
        config.code_chunker_instance if use_code_chunker else config.chunker_instance
    )
    return [c.text for c in chunker.chunk(text)]
@@ -0,0 +1,6 @@
from app.config import config


def embed_text(text: str) -> list[float]:
    """Embed a single text string using the configured embedding model."""
    return config.embedding_model_instance.embed(text)
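These two helpers are exactly what the pipeline service pairs up later in this commit when it builds chunk rows. A small sketch of that pairing; the wrapper function name is hypothetical, the body mirrors `IndexingPipelineService.index` further down.

```python
from app.db import Chunk


def build_chunks(markdown: str, use_code_chunker: bool = False) -> list[Chunk]:
    # One embedding per chunk, as the indexing service does below.
    return [
        Chunk(content=text, embedding=embed_text(text))
        for text in chunk_text(markdown, use_code_chunker=use_code_chunker)
    ]
```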
15 surfsense_backend/app/indexing_pipeline/document_hashing.py Normal file
@@ -0,0 +1,15 @@
import hashlib

from app.indexing_pipeline.connector_document import ConnectorDocument


def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
    """Return a stable SHA-256 hash identifying a document by its source identity."""
    combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()


def compute_content_hash(doc: ConnectorDocument) -> str:
    """Return a SHA-256 hash of the document's content scoped to its search space."""
    combined = f"{doc.search_space_id}:{doc.source_markdown}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()
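To make the dedup behaviour concrete, a small illustration with made-up field values: the identifier hash stays stable across syncs of the same source item, while the content hash changes whenever `source_markdown` changes.

```python
# Hypothetical values, for illustration only.
from app.db import DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument

doc = ConnectorDocument(
    title="Q3 Roadmap",
    source_markdown="# Q3 Roadmap\n...",
    unique_id="drive-file-abc123",
    document_type=DocumentType.FILE,
    search_space_id=1,
    created_by_id="user-uuid",
)

identifier = compute_unique_identifier_hash(doc)  # type + unique_id + search space
content = compute_content_hash(doc)              # changes when source_markdown changes
```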
@@ -0,0 +1,39 @@
from datetime import UTC, datetime

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value

from app.db import Document, DocumentStatus


async def rollback_and_persist_failure(
    session: AsyncSession, document: Document, message: str
) -> None:
    """Roll back the current transaction and best-effort persist a failed status.

    Called exclusively from except blocks — must never raise, or the new exception
    would chain with the original and mask it entirely.
    """
    try:
        await session.rollback()
    except Exception:
        return  # Session is completely dead; nothing further we can do.
    try:
        await session.refresh(document)
        document.updated_at = datetime.now(UTC)
        document.status = DocumentStatus.failed(message)
        await session.commit()
    except Exception:
        pass  # Best-effort; document will be retried on the next sync.


def attach_chunks_to_document(document: Document, chunks: list) -> None:
    """Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
    set_committed_value(document, "chunks", chunks)
    session = object_session(document)
    if session is not None:
        if document.id is not None:
            for chunk in chunks:
                chunk.document_id = document.id
        session.add_all(chunks)
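The failure helper is meant to be called only from `except` blocks, which is how `IndexingPipelineService.index` uses it later in this diff. A condensed sketch of that pattern:

```python
# Sketch of the intended call pattern (mirrors the indexing service below).
try:
    document.status = DocumentStatus.processing()
    await session.commit()
    # ... summarize, embed, chunk ...
except Exception as exc:
    # The helper never raises, so the original failure stays the primary error.
    await rollback_and_persist_failure(session, document, str(exc))
```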
@@ -0,0 +1,30 @@
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import optimize_content_for_context_window


async def summarize_document(
    source_markdown: str, llm, metadata: dict | None = None
) -> str:
    """Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
    model_name = getattr(llm, "model", "gpt-3.5-turbo")
    optimized_content = optimize_content_for_context_window(
        source_markdown, metadata, model_name
    )

    summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
    content_with_metadata = (
        f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
        f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
    )
    summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
    summary_content = summary_result.content

    if metadata:
        metadata_parts = ["# DOCUMENT METADATA"]
        for key, value in metadata.items():
            if value:
                metadata_parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
        metadata_section = "\n".join(metadata_parts)
        return f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"

    return summary_content
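For reference, a sketch of calling the summarizer and the prefix shape it produces when metadata is supplied; the summary text itself comes from whichever LLM is configured.

```python
# Illustrative call; values are made up.
summary = await summarize_document(
    source_markdown="# Q3 report\n...",
    llm=llm,
    metadata={"FILE_NAME": "q3_report.pdf"},
)
# summary starts roughly with:
# "# DOCUMENT METADATA\n**File Name:** q3_report.pdf\n\n# DOCUMENT SUMMARY\n\n..."
```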
124 surfsense_backend/app/indexing_pipeline/exceptions.py Normal file
@ -0,0 +1,124 @@
|
|||
from litellm.exceptions import (
|
||||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
AuthenticationError,
|
||||
BadGatewayError,
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
Timeout,
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
|
||||
# Tuples for use directly in except clauses.
|
||||
RETRYABLE_LLM_ERRORS = (
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
ServiceUnavailableError,
|
||||
BadGatewayError,
|
||||
InternalServerError,
|
||||
APIConnectionError,
|
||||
)
|
||||
|
||||
PERMANENT_LLM_ERRORS = (
|
||||
AuthenticationError,
|
||||
PermissionDeniedError,
|
||||
NotFoundError,
|
||||
BadRequestError,
|
||||
UnprocessableEntityError,
|
||||
APIResponseValidationError,
|
||||
)
|
||||
|
||||
# (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError).
|
||||
EMBEDDING_ERRORS = (
|
||||
RuntimeError, # local device failure or API backend normalization
|
||||
OSError, # model files missing or corrupted (local backends)
|
||||
MemoryError, # document too large for available RAM
|
||||
)
|
||||
|
||||
|
||||
class PipelineMessages:
|
||||
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
|
||||
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
|
||||
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
|
||||
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
|
||||
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
|
||||
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
|
||||
|
||||
LLM_AUTH = "LLM authentication failed. Check your API key."
|
||||
LLM_PERMISSION = "LLM request denied. Check your account permissions."
|
||||
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
|
||||
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
|
||||
LLM_UNPROCESSABLE = (
|
||||
"Document exceeds the LLM context window even after optimization."
|
||||
)
|
||||
LLM_RESPONSE = "LLM returned an invalid response."
|
||||
|
||||
EMBEDDING_FAILED = (
|
||||
"Embedding failed. Check your embedding model configuration or service."
|
||||
)
|
||||
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
|
||||
EMBEDDING_MEMORY = "Not enough memory to embed this document."
|
||||
|
||||
CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."
|
||||
|
||||
|
||||
def safe_exception_message(exc: Exception) -> str:
|
||||
try:
|
||||
return str(exc)
|
||||
except Exception:
|
||||
return "Something went wrong during indexing. Error details could not be retrieved."
|
||||
|
||||
|
||||
def llm_retryable_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RateLimitError):
|
||||
return PipelineMessages.RATE_LIMIT
|
||||
if isinstance(exc, Timeout):
|
||||
return PipelineMessages.LLM_TIMEOUT
|
||||
if isinstance(exc, ServiceUnavailableError):
|
||||
return PipelineMessages.LLM_UNAVAILABLE
|
||||
if isinstance(exc, BadGatewayError):
|
||||
return PipelineMessages.LLM_BAD_GATEWAY
|
||||
if isinstance(exc, InternalServerError):
|
||||
return PipelineMessages.LLM_SERVER_ERROR
|
||||
if isinstance(exc, APIConnectionError):
|
||||
return PipelineMessages.LLM_CONNECTION
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def llm_permanent_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, AuthenticationError):
|
||||
return PipelineMessages.LLM_AUTH
|
||||
if isinstance(exc, PermissionDeniedError):
|
||||
return PipelineMessages.LLM_PERMISSION
|
||||
if isinstance(exc, NotFoundError):
|
||||
return PipelineMessages.LLM_NOT_FOUND
|
||||
if isinstance(exc, BadRequestError):
|
||||
return PipelineMessages.LLM_BAD_REQUEST
|
||||
if isinstance(exc, UnprocessableEntityError):
|
||||
return PipelineMessages.LLM_UNPROCESSABLE
|
||||
if isinstance(exc, APIResponseValidationError):
|
||||
return PipelineMessages.LLM_RESPONSE
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def embedding_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RuntimeError):
|
||||
return PipelineMessages.EMBEDDING_FAILED
|
||||
if isinstance(exc, OSError):
|
||||
return PipelineMessages.EMBEDDING_MODEL
|
||||
if isinstance(exc, MemoryError):
|
||||
return PipelineMessages.EMBEDDING_MEMORY
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when generating the embedding."
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
import contextlib
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_embedder import embed_text
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
attach_chunks_to_document,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.document_summarizer import summarize_document
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
EMBEDDING_ERRORS,
|
||||
PERMANENT_LLM_ERRORS,
|
||||
RETRYABLE_LLM_ERRORS,
|
||||
PipelineMessages,
|
||||
embedding_message,
|
||||
llm_permanent_message,
|
||||
llm_retryable_message,
|
||||
safe_exception_message,
|
||||
)
|
||||
from app.indexing_pipeline.pipeline_logger import (
|
||||
PipelineLogContext,
|
||||
log_batch_aborted,
|
||||
log_chunking_overflow,
|
||||
log_doc_skipped_unknown,
|
||||
log_document_queued,
|
||||
log_document_requeued,
|
||||
log_document_updated,
|
||||
log_embedding_error,
|
||||
log_index_started,
|
||||
log_index_success,
|
||||
log_permanent_llm_error,
|
||||
log_race_condition,
|
||||
log_retryable_llm_error,
|
||||
log_unexpected_error,
|
||||
)
|
||||
|
||||
|
||||
class IndexingPipelineService:
|
||||
"""Single pipeline for indexing connector documents. All connectors use this service."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def prepare_for_indexing(
|
||||
self, connector_docs: list[ConnectorDocument]
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Persist new documents and detect changes, returning only those that need indexing.
|
||||
"""
|
||||
documents = []
|
||||
seen_hashes: set[str] = set()
|
||||
batch_ctx = PipelineLogContext(
|
||||
connector_id=connector_docs[0].connector_id if connector_docs else 0,
|
||||
search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
|
||||
unique_id="batch",
|
||||
)
|
||||
|
||||
for connector_doc in connector_docs:
|
||||
ctx = PipelineLogContext(
|
||||
connector_id=connector_doc.connector_id,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
unique_id=connector_doc.unique_id,
|
||||
)
|
||||
try:
|
||||
unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
|
||||
content_hash = compute_content_hash(connector_doc)
|
||||
|
||||
if unique_identifier_hash in seen_hashes:
|
||||
continue
|
||||
seen_hashes.add(unique_identifier_hash)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(Document).filter(
|
||||
Document.unique_identifier_hash == unique_identifier_hash
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing is not None:
|
||||
if existing.content_hash == content_hash:
|
||||
if existing.title != connector_doc.title:
|
||||
existing.title = connector_doc.title
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
if not DocumentStatus.is_state(
|
||||
existing.status, DocumentStatus.READY
|
||||
):
|
||||
existing.status = DocumentStatus.pending()
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
documents.append(existing)
|
||||
log_document_requeued(ctx)
|
||||
continue
|
||||
|
||||
existing.title = connector_doc.title
|
||||
existing.content_hash = content_hash
|
||||
existing.source_markdown = connector_doc.source_markdown
|
||||
existing.document_metadata = connector_doc.metadata
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
existing.status = DocumentStatus.pending()
|
||||
documents.append(existing)
|
||||
log_document_updated(ctx)
|
||||
continue
|
||||
|
||||
duplicate = await self.session.execute(
|
||||
select(Document).filter(Document.content_hash == content_hash)
|
||||
)
|
||||
if duplicate.scalars().first() is not None:
|
||||
continue
|
||||
|
||||
document = Document(
|
||||
title=connector_doc.title,
|
||||
document_type=connector_doc.document_type,
|
||||
content="Pending...",
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
source_markdown=connector_doc.source_markdown,
|
||||
document_metadata=connector_doc.metadata,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
connector_id=connector_doc.connector_id,
|
||||
created_by_id=connector_doc.created_by_id,
|
||||
updated_at=datetime.now(UTC),
|
||||
status=DocumentStatus.pending(),
|
||||
)
|
||||
self.session.add(document)
|
||||
documents.append(document)
|
||||
log_document_queued(ctx)
|
||||
|
||||
except Exception as e:
|
||||
log_doc_skipped_unknown(ctx, e)
|
||||
|
||||
try:
|
||||
await self.session.commit()
|
||||
return documents
|
||||
except IntegrityError:
|
||||
# A concurrent worker committed a document with the same content_hash
|
||||
# or unique_identifier_hash between our check and our INSERT.
|
||||
# The document already exists — roll back and let the next sync run handle it.
|
||||
log_race_condition(batch_ctx)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
except Exception as e:
|
||||
log_batch_aborted(batch_ctx, e)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
|
||||
async def index(
|
||||
self, document: Document, connector_doc: ConnectorDocument, llm
|
||||
) -> Document:
|
||||
"""
|
||||
Run summarization, embedding, and chunking for a document and persist the results.
|
||||
"""
|
||||
ctx = PipelineLogContext(
|
||||
connector_id=connector_doc.connector_id,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
unique_id=connector_doc.unique_id,
|
||||
doc_id=document.id,
|
||||
)
|
||||
try:
|
||||
log_index_started(ctx)
|
||||
document.status = DocumentStatus.processing()
|
||||
await self.session.commit()
|
||||
|
||||
if connector_doc.should_summarize and llm is not None:
|
||||
content = await summarize_document(
|
||||
connector_doc.source_markdown, llm, connector_doc.metadata
|
||||
)
|
||||
elif connector_doc.should_summarize and connector_doc.fallback_summary:
|
||||
content = connector_doc.fallback_summary
|
||||
else:
|
||||
content = connector_doc.source_markdown
|
||||
|
||||
embedding = embed_text(content)
|
||||
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
||||
chunks = [
|
||||
Chunk(content=text, embedding=embed_text(text))
|
||||
for text in chunk_text(
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
]
|
||||
|
||||
document.content = content
|
||||
document.embedding = embedding
|
||||
attach_chunks_to_document(document, chunks)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
|
||||
except RETRYABLE_LLM_ERRORS as e:
|
||||
log_retryable_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_retryable_message(e)
|
||||
)
|
||||
|
||||
except PERMANENT_LLM_ERRORS as e:
|
||||
log_permanent_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_permanent_message(e)
|
||||
)
|
||||
|
||||
except RecursionError as e:
|
||||
log_chunking_overflow(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
|
||||
)
|
||||
|
||||
except EMBEDDING_ERRORS as e:
|
||||
log_embedding_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, embedding_message(e)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log_unexpected_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, safe_exception_message(e)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await self.session.refresh(document)
|
||||
|
||||
return document
|
||||
126 surfsense_backend/app/indexing_pipeline/pipeline_logger.py Normal file
@ -0,0 +1,126 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineLogContext:
|
||||
connector_id: int | None
|
||||
search_space_id: int
|
||||
unique_id: str # always available from ConnectorDocument
|
||||
doc_id: int | None = None # set once the DB row exists (index phase only)
|
||||
|
||||
|
||||
class LogMessages:
|
||||
# prepare_for_indexing
|
||||
DOCUMENT_QUEUED = "New document queued for indexing."
|
||||
DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
|
||||
DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
|
||||
DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
|
||||
BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
|
||||
RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
|
||||
|
||||
# index
|
||||
INDEX_STARTED = "Document indexing started."
|
||||
INDEX_SUCCESS = "Document indexed successfully."
|
||||
LLM_RETRYABLE = (
|
||||
"Retryable LLM error — document marked failed, will retry on next sync."
|
||||
)
|
||||
LLM_PERMANENT = "Permanent LLM error — document marked failed."
|
||||
EMBEDDING_FAILED = "Embedding error — document marked failed."
|
||||
CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
|
||||
UNEXPECTED = "Unexpected error — document marked failed."
|
||||
|
||||
|
||||
def _format_context(ctx: PipelineLogContext) -> str:
|
||||
parts = [
|
||||
f"connector_id={ctx.connector_id}",
|
||||
f"search_space_id={ctx.search_space_id}",
|
||||
f"unique_id={ctx.unique_id}",
|
||||
]
|
||||
if ctx.doc_id is not None:
|
||||
parts.append(f"doc_id={ctx.doc_id}")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
|
||||
try:
|
||||
parts = [msg, _format_context(ctx)]
|
||||
for key, val in extra.items():
|
||||
parts.append(f"{key}={val}")
|
||||
return " ".join(parts)
|
||||
except Exception:
|
||||
return msg
|
||||
|
||||
|
||||
def _safe_log(
|
||||
level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra
|
||||
) -> None:
|
||||
# Logging must never raise — a broken log call inside an except block would
|
||||
# chain with the original exception and mask it entirely.
|
||||
try:
|
||||
message = _build_message(msg, ctx, **extra)
|
||||
level_fn(message, exc_info=exc_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── prepare_for_indexing ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def log_document_queued(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
|
||||
|
||||
|
||||
def log_document_updated(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)
|
||||
|
||||
|
||||
def log_document_requeued(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)
|
||||
|
||||
|
||||
def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(
|
||||
logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc
|
||||
)
|
||||
|
||||
|
||||
def log_race_condition(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)
|
||||
|
||||
|
||||
def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
# ── index ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def log_index_started(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)
|
||||
|
||||
|
||||
def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
|
||||
_safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)
|
||||
|
||||
|
||||
def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)
|
||||
|
|
@ -36,6 +36,7 @@ from .podcasts_routes import router as podcasts_router
|
|||
from .public_chat_routes import router as public_chat_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
from .reports_routes import router as reports_router
|
||||
from .sandbox_routes import router as sandbox_router
|
||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
from .slack_add_connector_route import router as slack_add_connector_router
|
||||
|
|
@ -50,6 +51,7 @@ router.include_router(editor_router)
|
|||
router.include_router(documents_router)
|
||||
router.include_router(notes_router)
|
||||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
router.include_router(reports_router) # Report CRUD and export (PDF/DOCX)
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_FILES_PER_UPLOAD = 10
|
||||
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB per file
|
||||
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024 # 200 MB total
|
||||
|
||||
|
||||
@router.post("/documents")
|
||||
async def create_documents(
|
||||
|
|
@ -148,12 +152,37 @@ async def create_documents_file_upload(
|
|||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files provided")
|
||||
|
||||
if len(files) > MAX_FILES_PER_UPLOAD:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Too many files. Maximum {MAX_FILES_PER_UPLOAD} files per upload.",
|
||||
)
|
||||
|
||||
total_size = 0
|
||||
for file in files:
|
||||
file_size = file.size or 0
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
total_size += file_size
|
||||
|
||||
if total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
created_documents: list[Document] = []
|
||||
files_to_process: list[
|
||||
tuple[Document, str, str]
|
||||
] = [] # (document, temp_path, filename)
|
||||
skipped_duplicates = 0
|
||||
duplicate_document_ids: list[int] = []
|
||||
actual_total_size = 0
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
|
|
@ -169,11 +198,28 @@ async def create_documents_file_upload(
|
|||
temp_path = temp_file.name
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
|
||||
actual_total_size += file_size
|
||||
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
file_size = len(content)
|
||||
|
||||
# Generate unique identifier for deduplication check
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file.filename or "unknown", search_space_id
|
||||
|
|
@ -373,10 +419,11 @@ async def read_documents(
|
|||
# Convert database objects to API-friendly format
|
||||
api_documents = []
|
||||
for doc in db_documents:
|
||||
# Get user name (display_name or email fallback)
|
||||
created_by_name = None
|
||||
created_by_email = None
|
||||
if doc.created_by:
|
||||
created_by_name = doc.created_by.display_name or doc.created_by.email
|
||||
created_by_name = doc.created_by.display_name
|
||||
created_by_email = doc.created_by.email
|
||||
|
||||
# Parse status from JSONB
|
||||
status_data = None
|
||||
|
|
@ -400,6 +447,7 @@ async def read_documents(
|
|||
search_space_id=doc.search_space_id,
|
||||
created_by_id=doc.created_by_id,
|
||||
created_by_name=created_by_name,
|
||||
created_by_email=created_by_email,
|
||||
status=status_data,
|
||||
)
|
||||
)
|
||||
|
|
@ -528,10 +576,11 @@ async def search_documents(
|
|||
# Convert database objects to API-friendly format
|
||||
api_documents = []
|
||||
for doc in db_documents:
|
||||
# Get user name (display_name or email fallback)
|
||||
created_by_name = None
|
||||
created_by_email = None
|
||||
if doc.created_by:
|
||||
created_by_name = doc.created_by.display_name or doc.created_by.email
|
||||
created_by_name = doc.created_by.display_name
|
||||
created_by_email = doc.created_by.email
|
||||
|
||||
# Parse status from JSONB
|
||||
status_data = None
|
||||
|
|
@ -555,6 +604,7 @@ async def search_documents(
|
|||
search_space_id=doc.search_space_id,
|
||||
created_by_id=doc.created_by_id,
|
||||
created_by_name=created_by_name,
|
||||
created_by_email=created_by_email,
|
||||
status=status_data,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -76,9 +76,9 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
# Google Drive OAuth scopes
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly", # Read-only access to Drive
|
||||
"https://www.googleapis.com/auth/userinfo.email", # User email
|
||||
"https://www.googleapis.com/auth/userinfo.profile", # User profile
|
||||
"https://www.googleapis.com/auth/drive",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"openid",
|
||||
]
|
||||
|
||||
|
|
@ -151,6 +151,75 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/drive/connector/reauth")
|
||||
async def reauth_drive(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Initiate Google Drive re-authentication to upgrade OAuth scopes.
|
||||
|
||||
Query params:
|
||||
space_id: Search space ID the connector belongs to
|
||||
connector_id: ID of the existing connector to re-authenticate
|
||||
|
||||
Returns:
|
||||
JSON with auth_url to redirect user to Google authorization
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Google Drive connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initiating Google Drive re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Google Drive re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Google re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/drive/connector/callback")
|
||||
async def drive_callback(
|
||||
request: Request,
|
||||
|
|
@ -214,6 +283,8 @@ async def drive_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
logger.info(
|
||||
f"Processing Google Drive callback for user {user_id}, space {space_id}"
|
||||
|
|
@ -253,7 +324,45 @@ async def drive_callback(
|
|||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
# Check for duplicate connector (same account already connected)
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
existing_start_page_token = db_connector.config.get("start_page_token")
|
||||
db_connector.config = {
|
||||
**creds_dict,
|
||||
"start_page_token": existing_start_page_token,
|
||||
}
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Google Drive connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
|||
- POST /threads/{thread_id}/messages - Append message
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
|
@ -52,9 +54,47 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
|||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _try_delete_sandbox(thread_id: int) -> None:
|
||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
delete_local_sandbox_files,
|
||||
delete_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
return
|
||||
|
||||
async def _bg() -> None:
|
||||
try:
|
||||
await delete_sandbox(thread_id)
|
||||
except Exception:
|
||||
_logger.warning(
|
||||
"Background sandbox delete failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
delete_local_sandbox_files(thread_id)
|
||||
except Exception:
|
||||
_logger.warning(
|
||||
"Local sandbox file cleanup failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(_bg())
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def check_thread_access(
|
||||
session: AsyncSession,
|
||||
thread: NewChatThread,
|
||||
|
|
@ -648,6 +688,9 @@ async def delete_thread(
|
|||
|
||||
await session.delete(db_thread)
|
||||
await session.commit()
|
||||
|
||||
_try_delete_sandbox(thread_id)
|
||||
|
||||
return {"message": "Thread deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
105 surfsense_backend/app/routes/sandbox_routes.py Normal file
@ -0,0 +1,105 @@
|
|||
"""Routes for downloading files from Daytona sandbox environments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import NewChatThread, Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MIME_TYPES: dict[str, str] = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
".svg": "image/svg+xml",
|
||||
".pdf": "application/pdf",
|
||||
".csv": "text/csv",
|
||||
".json": "application/json",
|
||||
".txt": "text/plain",
|
||||
".html": "text/html",
|
||||
".md": "text/markdown",
|
||||
".py": "text/x-python",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".zip": "application/zip",
|
||||
}
|
||||
|
||||
|
||||
def _guess_media_type(filename: str) -> str:
|
||||
ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else ""
|
||||
return MIME_TYPES.get(ext, "application/octet-stream")
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/sandbox/download")
|
||||
async def download_sandbox_file(
|
||||
thread_id: int,
|
||||
path: str = Query(..., description="Absolute path of the file inside the sandbox"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Download a file from the Daytona sandbox associated with a chat thread."""
|
||||
|
||||
from app.agents.new_chat.sandbox import get_or_create_sandbox, is_sandbox_enabled
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
raise HTTPException(status_code=404, detail="Sandbox is not enabled")
|
||||
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to access files in this thread",
|
||||
)
|
||||
|
||||
from app.agents.new_chat.sandbox import get_local_sandbox_file
|
||||
|
||||
# Prefer locally-persisted copy (sandbox may already be deleted)
|
||||
local_content = get_local_sandbox_file(thread_id, path)
|
||||
if local_content is not None:
|
||||
filename = path.rsplit("/", 1)[-1] if "/" in path else path
|
||||
media_type = _guess_media_type(filename)
|
||||
return Response(
|
||||
content=local_content,
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
# Fall back to live sandbox download
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(thread_id)
|
||||
raw_sandbox = sandbox._sandbox
|
||||
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
|
||||
except Exception as exc:
|
||||
logger.warning("Sandbox file download failed for %s: %s", path, exc)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Could not download file: {exc}"
|
||||
) from exc
|
||||
|
||||
filename = path.rsplit("/", 1)[-1] if "/" in path else path
|
||||
media_type = _guess_media_type(filename)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
|
@ -60,9 +60,8 @@ class DocumentRead(BaseModel):
|
|||
updated_at: datetime | None
|
||||
search_space_id: int
|
||||
created_by_id: UUID | None = None # User who created/uploaded this document
|
||||
created_by_name: str | None = (
|
||||
None # Display name or email of the user who created this document
|
||||
)
|
||||
created_by_name: str | None = None
|
||||
created_by_email: str | None = None
|
||||
status: DocumentStatusSchema | None = (
|
||||
None # Processing status (ready, processing, failed)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1303,10 +1303,9 @@ class ConnectorService:
|
|||
|
||||
sources_list = self._build_chunk_sources_from_documents(
|
||||
github_docs,
|
||||
description_fn=lambda chunk, _doc_info, metadata: metadata.get(
|
||||
"description"
|
||||
)
|
||||
or chunk.get("content", ""),
|
||||
description_fn=lambda chunk, _doc_info, metadata: (
|
||||
metadata.get("description") or chunk.get("content", "")
|
||||
),
|
||||
url_fn=lambda _doc_info, metadata: metadata.get("url", "") or "",
|
||||
)
|
||||
|
||||
|
|
|
|||
11 surfsense_backend/app/services/google_drive/__init__.py Normal file
@ -0,0 +1,11 @@
|
|||
from app.services.google_drive.tool_metadata_service import (
|
||||
GoogleDriveAccount,
|
||||
GoogleDriveFile,
|
||||
GoogleDriveToolMetadataService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GoogleDriveAccount",
|
||||
"GoogleDriveFile",
|
||||
"GoogleDriveToolMetadataService",
|
||||
]
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleDriveAccount:
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: SearchSourceConnector) -> "GoogleDriveAccount":
|
||||
return cls(id=connector.id, name=connector.name)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"id": self.id, "name": self.name}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleDriveFile:
|
||||
file_id: str
|
||||
name: str
|
||||
mime_type: str
|
||||
web_view_link: str
|
||||
connector_id: int
|
||||
document_id: int
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "GoogleDriveFile":
|
||||
meta = document.document_metadata or {}
|
||||
return cls(
|
||||
file_id=meta.get("google_drive_file_id", ""),
|
||||
name=meta.get("google_drive_file_name", document.title),
|
||||
mime_type=meta.get("google_drive_mime_type", ""),
|
||||
web_view_link=meta.get("web_view_link", ""),
|
||||
connector_id=document.connector_id,
|
||||
document_id=document.id,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"file_id": self.file_id,
|
||||
"name": self.name,
|
||||
"mime_type": self.mime_type,
|
||||
"web_view_link": self.web_view_link,
|
||||
"connector_id": self.connector_id,
|
||||
"document_id": self.document_id,
|
||||
}
|
||||
|
||||
|
||||
class GoogleDriveToolMetadataService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
||||
|
||||
if not accounts:
|
||||
return {
|
||||
"accounts": [],
|
||||
"supported_types": [],
|
||||
"error": "No Google Drive account connected",
|
||||
}
|
||||
|
||||
return {
|
||||
"accounts": [acc.to_dict() for acc in accounts],
|
||||
"supported_types": ["google_doc", "google_sheet"],
|
||||
}
|
||||
|
||||
async def get_trash_context(
|
||||
self, search_space_id: int, user_id: str, file_name: str
|
||||
) -> dict:
|
||||
result = await self._db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.GOOGLE_DRIVE_FILE,
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
if not document:
|
||||
return {
|
||||
"error": (
|
||||
f"File '{file_name}' not found in your indexed Google Drive files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
)
|
||||
}
|
||||
|
||||
if not document.connector_id:
|
||||
return {"error": "Document has no associated connector"}
|
||||
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
return {"error": "Connector not found or access denied"}
|
||||
|
||||
account = GoogleDriveAccount.from_connector(connector)
|
||||
file = GoogleDriveFile.from_document(document)
|
||||
|
||||
return {
|
||||
"account": account.to_dict(),
|
||||
"file": file.to_dict(),
|
||||
}
|
||||
|
||||
async def _get_google_drive_accounts(
|
||||
self, search_space_id: int, user_id: str
|
||||
) -> list[GoogleDriveAccount]:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
return [GoogleDriveAccount.from_connector(c) for c in connectors]
|
||||
|
|
@ -9,14 +9,15 @@ Supports loading LLM configurations from:
|
|||
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
|
@ -30,7 +31,13 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument, async_session_maker
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
Document,
|
||||
Report,
|
||||
SurfsenseDocsDocument,
|
||||
async_session_maker,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
|
|
@ -187,6 +194,7 @@ class StreamResult:
|
|||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -404,6 +412,21 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
cmd = (
|
||||
tool_input.get("command", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input)
|
||||
)
|
||||
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
|
||||
last_active_step_title = "Running command"
|
||||
last_active_step_items = [f"$ {display_cmd}"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Running command",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
else:
|
||||
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
|
||||
last_active_step_items = []
|
||||
|
|
@ -620,6 +643,32 @@ async def _stream_agent_events(
|
|||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
else str(tool_output)
|
||||
)
|
||||
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
|
||||
exit_code_val = int(m.group(1)) if m else None
|
||||
if exit_code_val is not None and exit_code_val == 0:
|
||||
completed_items = [
|
||||
*last_active_step_items,
|
||||
"Completed successfully",
|
||||
]
|
||||
elif exit_code_val is not None:
|
||||
completed_items = [
|
||||
*last_active_step_items,
|
||||
f"Exit code: {exit_code_val}",
|
||||
]
|
||||
else:
|
||||
completed_items = [*last_active_step_items, "Finished"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Running command",
|
||||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "ls":
|
||||
if isinstance(tool_output, dict):
|
||||
ls_output = tool_output.get("result", "")
|
||||
|
|
@ -804,6 +853,8 @@ async def _stream_agent_events(
|
|||
"create_linear_issue",
|
||||
"update_linear_issue",
|
||||
"delete_linear_issue",
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
):
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
|
|
@ -811,6 +862,36 @@ async def _stream_agent_events(
|
|||
if isinstance(tool_output, dict)
|
||||
else {"result": tool_output},
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
else str(tool_output)
|
||||
)
|
||||
exit_code: int | None = None
|
||||
output_text = raw_text
|
||||
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
|
||||
if m:
|
||||
exit_code = int(m.group(1))
|
||||
om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
|
||||
output_text = om.group(1) if om else ""
|
||||
thread_id_str = config.get("configurable", {}).get("thread_id", "")
|
||||
|
||||
for sf_match in re.finditer(
|
||||
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
|
||||
):
|
||||
fpath = sf_match.group(1).strip()
|
||||
if fpath and fpath not in result.sandbox_files:
|
||||
result.sandbox_files.append(fpath)
|
||||
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{
|
||||
"exit_code": exit_code,
|
||||
"output": output_text,
|
||||
"thread_id": thread_id_str,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
|
|
@ -879,6 +960,36 @@ async def _stream_agent_events(
|
|||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
||||
|
||||
|
||||
def _try_persist_and_delete_sandbox(
    thread_id: int,
    sandbox_files: list[str],
) -> None:
    """Fire-and-forget: persist sandbox files locally then delete the sandbox."""
    from app.agents.new_chat.sandbox import (
        is_sandbox_enabled,
        persist_and_delete_sandbox,
    )

    if not is_sandbox_enabled():
        return

    async def _run() -> None:
        try:
            await persist_and_delete_sandbox(thread_id, sandbox_files)
        except Exception:
            logging.getLogger(__name__).warning(
                "persist_and_delete_sandbox failed for thread %s",
                thread_id,
                exc_info=True,
            )

    try:
        loop = asyncio.get_running_loop()
        loop.create_task(_run())
    except RuntimeError:
        pass
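A short usage note on the scheduling pattern above: the helper never awaits the cleanup, it only schedules it when an event loop is already running, and silently does nothing otherwise. A hypothetical illustration (not part of the diff) of the two call sites it is designed for:

```python
async def finish_stream(chat_id: int, sandbox_files: list[str]) -> None:
    # Called from inside the streaming coroutine: a loop is running, so the
    # cleanup coroutine is scheduled and the stream can return immediately.
    _try_persist_and_delete_sandbox(chat_id, sandbox_files)


def finish_from_sync_code(chat_id: int) -> None:
    # No running loop here: asyncio.get_running_loop() raises RuntimeError
    # inside the helper, and cleanup is skipped without raising.
    _try_persist_and_delete_sandbox(chat_id, [])
```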
|
||||
|
||||
|
||||
async def stream_new_chat(
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
|
|
@ -915,6 +1026,7 @@ async def stream_new_chat(
|
|||
str: SSE formatted response strings
|
||||
"""
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
|
|
@ -975,6 +1087,22 @@ async def stream_new_chat(
|
|||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||
checkpointer = await get_checkpointer()
|
||||
|
||||
# Optionally provision a sandboxed code execution environment
|
||||
sandbox_backend = None
|
||||
from app.agents.new_chat.sandbox import (
|
||||
get_or_create_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
|
||||
if is_sandbox_enabled():
|
||||
try:
|
||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
||||
except Exception as sandbox_err:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Sandbox creation failed, continuing without execute tool: %s",
|
||||
sandbox_err,
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
|
|
@ -987,6 +1115,7 @@ async def stream_new_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
sandbox_backend=sandbox_backend,
|
||||
)
|
||||
|
||||
# Build input with message history
|
||||
|
|
@ -1180,7 +1309,6 @@ async def stream_new_chat(
|
|||
items=initial_items,
|
||||
)
|
||||
|
||||
stream_result = StreamResult()
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
|
|
@ -1294,6 +1422,8 @@ async def stream_new_chat(
|
|||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
|
||||
async def stream_resume_chat(
|
||||
chat_id: int,
|
||||
|
|
@ -1305,6 +1435,7 @@ async def stream_resume_chat(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
|
||||
try:
|
||||
if user_id:
|
||||
|
|
@ -1352,6 +1483,22 @@ async def stream_resume_chat(
|
|||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||
|
||||
checkpointer = await get_checkpointer()
|
||||
|
||||
sandbox_backend = None
|
||||
from app.agents.new_chat.sandbox import (
|
||||
get_or_create_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
|
||||
if is_sandbox_enabled():
|
||||
try:
|
||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
||||
except Exception as sandbox_err:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Sandbox creation failed, continuing without execute tool: %s",
|
||||
sandbox_err,
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
agent = await create_surfsense_deep_agent(
|
||||
|
|
@ -1365,6 +1512,7 @@ async def stream_resume_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
sandbox_backend=sandbox_backend,
|
||||
)
|
||||
|
||||
# Release the transaction before streaming (same rationale as stream_new_chat).
|
||||
|
|
@ -1380,7 +1528,6 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_message_start()
|
||||
yield streaming_service.format_start_step()
|
||||
|
||||
stream_result = StreamResult()
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
|
|
@ -1423,3 +1570,5 @@ async def stream_resume_chat(
|
|||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentStatus, DocumentType, Log, Notification
|
||||
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
|
@ -33,7 +34,6 @@ from .base import (
|
|||
check_document_by_unique_identifier,
|
||||
check_duplicate_document,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
from .markdown_processor import add_received_markdown_file_document
|
||||
|
||||
|
|
@ -1632,6 +1632,8 @@ async def process_file_in_background_with_document(
|
|||
from app.config import config as app_config
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
doc_id = document.id
|
||||
|
||||
try:
|
||||
markdown_content = None
|
||||
etl_service = None
|
||||
|
|
@ -1855,7 +1857,7 @@ async def process_file_in_background_with_document(
|
|||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
existing_by_content = await check_duplicate_document(session, content_hash)
|
||||
if existing_by_content and existing_by_content.id != document.id:
|
||||
if existing_by_content and existing_by_content.id != doc_id:
|
||||
# Duplicate content found - mark this document as failed
|
||||
logging.info(
|
||||
f"Duplicate content detected for {filename}, "
|
||||
|
|
@ -1863,7 +1865,7 @@ async def process_file_in_background_with_document(
|
|||
)
|
||||
return None
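Duplicate detection here hinges on `generate_content_hash`, whose implementation is not part of this diff. A plausible minimal sketch, assuming the helper simply scopes a SHA-256 digest of the extracted markdown to the search space:

```python
import hashlib


def generate_content_hash_sketch(markdown_content: str, search_space_id: int) -> str:
    """Hypothetical stand-in for the real generate_content_hash helper."""
    payload = f"{search_space_id}:{markdown_content}".encode("utf-8")
    return hashlib.sha256(payload).hexdigest()
```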
|
||||
|
||||
# ===== STEP 3: Generate embeddings and chunks =====
|
||||
# ===== STEP 3+4: Index via pipeline =====
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session, notification, stage="chunking"
|
||||
|
|
@ -1871,57 +1873,23 @@ async def process_file_in_background_with_document(
|
|||
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
|
||||
if user_llm:
|
||||
document_metadata = {
|
||||
"file_name": filename,
|
||||
"etl_service": etl_service,
|
||||
"document_type": "File Document",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
markdown_content, user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
# Fallback: use truncated content as summary
|
||||
summary_content = markdown_content[:4000]
|
||||
from app.config import config
|
||||
|
||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
|
||||
# ===== STEP 4: Update document to READY =====
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
document.title = filename
|
||||
document.content = summary_content
|
||||
document.content_hash = content_hash
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"FILE_NAME": filename,
|
||||
"ETL_SERVICE": etl_service or "UNKNOWN",
|
||||
**(document.document_metadata or {}),
|
||||
}
|
||||
flag_modified(document, "document_metadata")
|
||||
|
||||
# Use safe_set_chunks to avoid async issues
|
||||
safe_set_chunks(document, chunks)
|
||||
|
||||
document.source_markdown = markdown_content
|
||||
document.content_needs_reindexing = False
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready() # Shows checkmark in UI
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
await index_uploaded_file(
|
||||
markdown_content=markdown_content,
|
||||
filename=filename,
|
||||
etl_service=etl_service,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
llm=user_llm,
|
||||
)
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file: {filename}",
|
||||
{
|
||||
"document_id": document.id,
|
||||
"document_id": doc_id,
|
||||
"content_hash": content_hash,
|
||||
"file_type": etl_service,
|
||||
"chunks_count": len(chunks),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -1946,7 +1914,7 @@ async def process_file_in_background_with_document(
|
|||
{
|
||||
"error_type": type(e).__name__,
|
||||
"filename": filename,
|
||||
"document_id": document.id,
|
||||
"document_id": doc_id,
|
||||
},
|
||||
)
|
||||
logging.error(f"Error processing file with document: {error_message}")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ dependencies = [
|
|||
"kokoro>=0.9.4",
|
||||
"linkup-sdk>=0.2.4",
|
||||
"llama-cloud-services>=0.6.25",
|
||||
"Markdown>=3.7",
|
||||
"markdownify>=0.14.1",
|
||||
"notion-client>=2.3.0",
|
||||
"numpy>=1.24.0",
|
||||
|
|
@ -65,11 +66,16 @@ dependencies = [
|
|||
"pypandoc_binary>=1.16.2",
|
||||
"typst>=0.14.0",
|
||||
"deepagents>=0.4.3",
|
||||
"langchain-daytona>=0.0.2",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"ruff>=0.12.5",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"pytest-mock>=3.14",
|
||||
"httpx>=0.28.1",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
|
@ -157,10 +163,28 @@ line-ending = "auto"
|
|||
|
||||
[tool.ruff.lint.isort]
|
||||
# Group imports by type
|
||||
known-first-party = ["app"]
|
||||
known-first-party = ["app", "tests"]
|
||||
force-single-line = false
|
||||
combine-as-imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
asyncio_default_test_loop_scope = "session"
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
addopts = "-v --tb=short -x --strict-markers -ra --durations=5"
|
||||
markers = [
|
||||
"unit: pure logic tests, no DB or external services",
|
||||
"integration: tests that require a real PostgreSQL database",
|
||||
"e2e: tests requiring a running backend and real HTTP calls"
|
||||
]
|
||||
filterwarnings = [
|
||||
"ignore::UserWarning:chonkie",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["app*", "alembic*"]
|
||||
|
|
|
|||
surfsense_backend/tests/__init__.py (new file, 0 lines)
surfsense_backend/tests/conftest.py (new file, 63 lines)
|
|
@ -0,0 +1,63 @@
|
|||
"""Root conftest — shared fixtures available to all test modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.db import DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
|
||||
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
|
||||
|
||||
# Shared DB URL referenced by both e2e and integration helper functions.
|
||||
DATABASE_URL = os.environ.get(
|
||||
"TEST_DATABASE_URL",
|
||||
os.environ.get("DATABASE_URL", ""),
|
||||
).replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit test fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user_id() -> str:
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_space_id() -> int:
|
||||
return 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_connector_id() -> int:
|
||||
return 42
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_connector_document():
|
||||
"""
|
||||
Generic factory for unit tests. Overridden in tests/integration/conftest.py
|
||||
with real DB-backed IDs for integration tests.
|
||||
"""
|
||||
|
||||
def _make(**overrides):
|
||||
defaults = {
|
||||
"title": "Test Document",
|
||||
"source_markdown": "## Heading\n\nSome content.",
|
||||
"unique_id": "test-id-001",
|
||||
"document_type": DocumentType.CLICKUP_CONNECTOR,
|
||||
"search_space_id": 1,
|
||||
"connector_id": 1,
|
||||
"created_by_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ConnectorDocument(**defaults)
|
||||
|
||||
return _make
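For illustration, a unit test consuming the factory above might look like this (hypothetical test, assuming `ConnectorDocument` exposes its constructor fields as attributes):

```python
import pytest

pytestmark = pytest.mark.unit


def test_factory_overrides_title(make_connector_document):
    doc = make_connector_document(title="Quarterly Report")
    assert doc.title == "Quarterly Report"
    # Fields that were not overridden keep the factory defaults.
    assert doc.search_space_id == 1
    assert doc.unique_id == "test-id-001"
```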
|
||||
surfsense_backend/tests/e2e/__init__.py (new file, 0 lines)
surfsense_backend/tests/e2e/conftest.py (new file, 198 lines)
|
|
@ -0,0 +1,198 @@
|
|||
"""E2e conftest — fixtures that require a running backend + database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import asyncpg
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.conftest import DATABASE_URL
|
||||
from tests.utils.helpers import (
|
||||
BACKEND_URL,
|
||||
TEST_EMAIL,
|
||||
auth_headers,
|
||||
delete_document,
|
||||
get_auth_token,
|
||||
get_search_space_id,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend connectivity fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def backend_url() -> str:
|
||||
return BACKEND_URL
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def auth_token(backend_url: str) -> str:
|
||||
"""Authenticate once per session, registering the user if needed."""
|
||||
async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client:
|
||||
return await get_auth_token(client)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def search_space_id(backend_url: str, auth_token: str) -> int:
|
||||
"""Discover the first search space belonging to the test user."""
|
||||
async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client:
|
||||
return await get_search_space_id(client, auth_token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def _purge_test_search_space(
|
||||
search_space_id: int,
|
||||
):
|
||||
"""
|
||||
Delete all documents in the test search space before the session starts.
|
||||
|
||||
Uses direct database access to bypass the API's 409 protection on
|
||||
pending/processing documents. This ensures stuck documents from
|
||||
previous crashed runs are always cleaned up.
|
||||
"""
|
||||
deleted = await _force_delete_documents_db(search_space_id)
|
||||
if deleted:
|
||||
print(
|
||||
f"\n[purge] Deleted {deleted} stale document(s) from search space {search_space_id}"
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def headers(auth_token: str) -> dict[str, str]:
|
||||
"""Authorization headers reused across all tests in the session."""
|
||||
return auth_headers(auth_token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(backend_url: str) -> AsyncGenerator[httpx.AsyncClient]:
|
||||
"""Per-test async HTTP client pointing at the running backend."""
|
||||
async with httpx.AsyncClient(base_url=backend_url, timeout=180.0) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_doc_ids() -> list[int]:
|
||||
"""Accumulator for document IDs that should be deleted after the test."""
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _cleanup_documents(
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
"""
|
||||
Runs after every test. Tries the API first for clean deletes, then
|
||||
falls back to direct DB access for any stuck documents.
|
||||
"""
|
||||
yield
|
||||
|
||||
remaining_ids: list[int] = []
|
||||
for doc_id in cleanup_doc_ids:
|
||||
try:
|
||||
resp = await delete_document(client, headers, doc_id)
|
||||
if resp.status_code == 409:
|
||||
remaining_ids.append(doc_id)
|
||||
except Exception:
|
||||
remaining_ids.append(doc_id)
|
||||
|
||||
if remaining_ids:
|
||||
conn = await asyncpg.connect(DATABASE_URL)
|
||||
try:
|
||||
await conn.execute(
|
||||
"DELETE FROM documents WHERE id = ANY($1::int[])",
|
||||
remaining_ids,
|
||||
)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Page-limit helpers (direct DB access)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _force_delete_documents_db(search_space_id: int) -> int:
|
||||
"""
|
||||
Bypass the API and delete documents directly from the database.
|
||||
|
||||
This handles stuck documents in pending/processing state that the API
|
||||
refuses to delete (409 Conflict). Chunks are cascade-deleted by the
|
||||
foreign key constraint.
|
||||
|
||||
Returns the number of deleted rows.
|
||||
"""
|
||||
conn = await asyncpg.connect(DATABASE_URL)
|
||||
try:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM documents WHERE search_space_id = $1",
|
||||
search_space_id,
|
||||
)
|
||||
return int(result.split()[-1])
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
async def _get_user_page_usage(email: str) -> tuple[int, int]:
|
||||
"""Return ``(pages_used, pages_limit)`` for the given user."""
|
||||
conn = await asyncpg.connect(DATABASE_URL)
|
||||
try:
|
||||
row = await conn.fetchrow(
|
||||
'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
|
||||
email,
|
||||
)
|
||||
assert row is not None, f"User {email!r} not found in database"
|
||||
return row["pages_used"], row["pages_limit"]
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
async def _set_user_page_limits(
|
||||
email: str, *, pages_used: int, pages_limit: int
|
||||
) -> None:
|
||||
"""Overwrite ``pages_used`` and ``pages_limit`` for the given user."""
|
||||
conn = await asyncpg.connect(DATABASE_URL)
|
||||
try:
|
||||
await conn.execute(
|
||||
'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
|
||||
pages_used,
|
||||
pages_limit,
|
||||
email,
|
||||
)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def page_limits():
|
||||
"""
|
||||
Fixture that exposes helpers for manipulating the test user's page limits.
|
||||
|
||||
Automatically restores the original values after each test.
|
||||
|
||||
Usage inside a test::
|
||||
|
||||
await page_limits.set(pages_used=0, pages_limit=100)
|
||||
used, limit = await page_limits.get()
|
||||
"""
|
||||
|
||||
class _PageLimits:
|
||||
async def set(self, *, pages_used: int, pages_limit: int) -> None:
|
||||
await _set_user_page_limits(
|
||||
TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
|
||||
)
|
||||
|
||||
async def get(self) -> tuple[int, int]:
|
||||
return await _get_user_page_usage(TEST_EMAIL)
|
||||
|
||||
original = await _get_user_page_usage(TEST_EMAIL)
|
||||
yield _PageLimits()
|
||||
await _set_user_page_limits(
|
||||
TEST_EMAIL, pages_used=original[0], pages_limit=original[1]
|
||||
)
|
||||
surfsense_backend/tests/e2e/test_document_upload.py (new file, 592 lines)
|
|
@ -0,0 +1,592 @@
|
|||
"""
|
||||
End-to-end tests for manual document upload.
|
||||
|
||||
These tests exercise the full pipeline:
|
||||
API upload → Celery task → ETL extraction → chunking → embedding → DB storage
|
||||
|
||||
Prerequisites (must be running):
|
||||
- FastAPI backend
|
||||
- PostgreSQL + pgvector
|
||||
- Redis
|
||||
- Celery worker
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.utils.helpers import (
|
||||
FIXTURES_DIR,
|
||||
delete_document,
|
||||
get_document,
|
||||
poll_document_status,
|
||||
upload_file,
|
||||
upload_multiple_files,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers local to this module
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assert_document_ready(doc: dict, *, expected_filename: str) -> None:
|
||||
"""Common assertions for a successfully processed document."""
|
||||
assert doc["title"] == expected_filename
|
||||
assert doc["document_type"] == "FILE"
|
||||
assert doc["content"], "Document content (summary) should not be empty"
|
||||
assert doc["content_hash"], "content_hash should be set"
|
||||
assert doc["document_metadata"].get("FILE_NAME") == expected_filename
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test A: Upload a .txt file (direct read path — no ETL service needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTxtFileUpload:
|
||||
"""Upload a plain-text file and verify the full pipeline."""
|
||||
|
||||
async def test_upload_txt_returns_document_id(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["pending_files"] >= 1
|
||||
assert len(body["document_ids"]) >= 1
|
||||
cleanup_doc_ids.extend(body["document_ids"])
|
||||
|
||||
async def test_txt_processing_reaches_ready(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
async def test_txt_document_fields_populated(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
doc = await get_document(client, headers, doc_ids[0])
|
||||
_assert_document_ready(doc, expected_filename="sample.txt")
|
||||
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test B: Upload a .md file (markdown direct-read path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarkdownFileUpload:
|
||||
"""Upload a Markdown file and verify the full pipeline."""
|
||||
|
||||
async def test_md_processing_reaches_ready(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.md", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
async def test_md_document_fields_populated(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.md", search_space_id=search_space_id
|
||||
)
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
doc = await get_document(client, headers, doc_ids[0])
|
||||
_assert_document_ready(doc, expected_filename="sample.md")
|
||||
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test C: Upload a .pdf file (ETL path — Docling / Unstructured)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPdfFileUpload:
|
||||
"""Upload a PDF and verify it goes through the ETL extraction pipeline."""
|
||||
|
||||
async def test_pdf_processing_reaches_ready(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
async def test_pdf_document_fields_populated(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
|
||||
doc = await get_document(client, headers, doc_ids[0])
|
||||
_assert_document_ready(doc, expected_filename="sample.pdf")
|
||||
assert doc["document_metadata"]["ETL_SERVICE"] in {
|
||||
"DOCLING",
|
||||
"UNSTRUCTURED",
|
||||
"LLAMACLOUD",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test D: Upload multiple files in a single request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiFileUpload:
|
||||
"""Upload several files at once and verify all are processed."""
|
||||
|
||||
async def test_multi_upload_returns_all_ids(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_multiple_files(
|
||||
client,
|
||||
headers,
|
||||
["sample.txt", "sample.md"],
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["pending_files"] == 2
|
||||
assert len(body["document_ids"]) == 2
|
||||
cleanup_doc_ids.extend(body["document_ids"])
|
||||
|
||||
async def test_multi_upload_all_reach_ready(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_multiple_files(
|
||||
client,
|
||||
headers,
|
||||
["sample.txt", "sample.md"],
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test E: Duplicate file upload (same file uploaded twice)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDuplicateFileUpload:
|
||||
"""
|
||||
Uploading the exact same file a second time should be detected as a
|
||||
duplicate via ``unique_identifier_hash``.
|
||||
"""
|
||||
|
||||
async def test_duplicate_file_is_skipped(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
# First upload
|
||||
resp1 = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp1.status_code == 200
|
||||
first_ids = resp1.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(first_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, first_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
# Second upload of the same file
|
||||
resp2 = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp2.status_code == 200
|
||||
|
||||
body2 = resp2.json()
|
||||
assert body2["skipped_duplicates"] >= 1
|
||||
assert len(body2["duplicate_document_ids"]) >= 1
|
||||
cleanup_doc_ids.extend(body2.get("document_ids", []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test F: Duplicate content detection (different name, same content)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDuplicateContentDetection:
|
||||
"""
|
||||
Uploading a file with a different name but identical content should be
|
||||
detected as duplicate content via ``content_hash``.
|
||||
"""
|
||||
|
||||
async def test_same_content_different_name_detected(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
tmp_path: Path,
|
||||
):
|
||||
# First upload
|
||||
resp1 = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp1.status_code == 200
|
||||
first_ids = resp1.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(first_ids)
|
||||
await poll_document_status(
|
||||
client, headers, first_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
# Copy fixture content to a differently named temp file
|
||||
src = FIXTURES_DIR / "sample.txt"
|
||||
dest = tmp_path / "renamed_sample.txt"
|
||||
shutil.copy2(src, dest)
|
||||
|
||||
with open(dest, "rb") as f:
|
||||
resp2 = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files={"files": ("renamed_sample.txt", f)},
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp2.status_code == 200
|
||||
second_ids = resp2.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(second_ids)
|
||||
assert second_ids, (
|
||||
"Expected at least one document id for renamed duplicate content upload"
|
||||
)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, second_ids, search_space_id=search_space_id
|
||||
)
|
||||
for did in second_ids:
|
||||
assert statuses[did]["status"]["state"] == "failed"
|
||||
assert "duplicate" in statuses[did]["status"].get("reason", "").lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test G: Empty / corrupt file handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmptyFileUpload:
|
||||
"""An empty file should be processed but ultimately fail gracefully."""
|
||||
|
||||
async def test_empty_pdf_fails(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "empty.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
assert doc_ids, "Expected at least one document id for empty PDF upload"
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "failed"
|
||||
assert statuses[did]["status"].get("reason"), (
|
||||
"Failed document should include a reason"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test H: Upload without authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnauthenticatedUpload:
|
||||
"""Requests without a valid JWT should be rejected."""
|
||||
|
||||
async def test_upload_without_auth_returns_401(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
search_space_id: int,
|
||||
):
|
||||
file_path = FIXTURES_DIR / "sample.txt"
|
||||
with open(file_path, "rb") as f:
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
files={"files": ("sample.txt", f)},
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test I: Upload with no files attached
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoFilesUpload:
|
||||
"""Submitting the form with zero files should return a validation error."""
|
||||
|
||||
async def test_no_files_returns_error(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code in {400, 422}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test J: Document deletion after successful upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentDeletion:
|
||||
"""Upload, wait for ready, delete, then verify it's gone."""
|
||||
|
||||
async def test_delete_processed_document(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
del_resp = await delete_document(client, headers, doc_ids[0])
|
||||
assert del_resp.status_code == 200
|
||||
|
||||
get_resp = await client.get(
|
||||
f"/api/v1/documents/{doc_ids[0]}",
|
||||
headers=headers,
|
||||
)
|
||||
assert get_resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test K: Cannot delete a document while it is still processing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteWhileProcessing:
|
||||
"""Attempting to delete a pending/processing document should be rejected."""
|
||||
|
||||
async def test_delete_pending_document_returns_409(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
# Immediately try to delete before processing finishes
|
||||
del_resp = await delete_document(client, headers, doc_ids[0])
|
||||
assert del_resp.status_code == 409
|
||||
|
||||
# Let it finish so cleanup can work
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test L: Status polling returns correct structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentSearchability:
|
||||
"""After upload reaches ready, the document must appear in the title search."""
|
||||
|
||||
async def test_uploaded_document_appears_in_search(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
search_resp = await client.get(
|
||||
"/api/v1/documents/search",
|
||||
headers=headers,
|
||||
params={"title": "sample", "search_space_id": search_space_id},
|
||||
)
|
||||
assert search_resp.status_code == 200
|
||||
|
||||
result_ids = [d["id"] for d in search_resp.json()["items"]]
|
||||
assert doc_ids[0] in result_ids, (
|
||||
f"Uploaded document {doc_ids[0]} not found in search results: {result_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestStatusPolling:
|
||||
"""Verify the status endpoint returns well-formed responses."""
|
||||
|
||||
async def test_status_endpoint_returns_items(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
status_resp = await client.get(
|
||||
"/api/v1/documents/status",
|
||||
headers=headers,
|
||||
params={
|
||||
"search_space_id": search_space_id,
|
||||
"document_ids": ",".join(str(d) for d in doc_ids),
|
||||
},
|
||||
)
|
||||
assert status_resp.status_code == 200
|
||||
|
||||
body = status_resp.json()
|
||||
assert "items" in body
|
||||
assert len(body["items"]) == len(doc_ids)
|
||||
for item in body["items"]:
|
||||
assert "id" in item
|
||||
assert "status" in item
|
||||
assert "state" in item["status"]
|
||||
assert item["status"]["state"] in {
|
||||
"pending",
|
||||
"processing",
|
||||
"ready",
|
||||
"failed",
|
||||
}
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
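`poll_document_status` itself lives in `tests/utils/helpers.py` and is not shown in this diff. A minimal sketch of how such a poller might work against the status endpoint exercised above (the polling interval and exact error handling are assumptions):

```python
import asyncio
import time

import httpx


async def poll_document_status_sketch(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    doc_ids: list[int],
    *,
    search_space_id: int,
    timeout: float = 120.0,
    interval: float = 2.0,
) -> dict[int, dict]:
    """Poll /documents/status until every document is ready or failed."""
    deadline = time.monotonic() + timeout
    while True:
        resp = await client.get(
            "/api/v1/documents/status",
            headers=headers,
            params={
                "search_space_id": search_space_id,
                "document_ids": ",".join(str(d) for d in doc_ids),
            },
        )
        resp.raise_for_status()
        items = {item["id"]: item for item in resp.json()["items"]}
        if all(items[d]["status"]["state"] in {"ready", "failed"} for d in doc_ids):
            return items
        if time.monotonic() > deadline:
            raise TimeoutError(f"Documents {doc_ids} did not finish within {timeout}s")
        await asyncio.sleep(interval)
```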
|
||||
surfsense_backend/tests/e2e/test_page_limits.py (new file, 323 lines)
|
|
@ -0,0 +1,323 @@
|
|||
"""
|
||||
End-to-end tests for page-limit enforcement during document upload.
|
||||
|
||||
These tests manipulate the test user's ``pages_used`` / ``pages_limit``
|
||||
columns directly in the database and then exercise the upload pipeline to
|
||||
verify that:
|
||||
|
||||
- Uploads are rejected *before* ETL when the limit is exhausted.
|
||||
- ``pages_used`` increases after a successful upload.
|
||||
- A ``page_limit_exceeded`` notification is created on rejection.
|
||||
- ``pages_used`` is not modified when a document fails processing.
|
||||
|
||||
All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``)
|
||||
so no additional processing time is introduced.
|
||||
|
||||
Prerequisites (must be running):
|
||||
- FastAPI backend
|
||||
- PostgreSQL + pgvector
|
||||
- Redis
|
||||
- Celery worker
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.utils.helpers import (
|
||||
get_notifications,
|
||||
poll_document_status,
|
||||
upload_file,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test A: Successful upload increments pages_used
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPageUsageIncrementsOnSuccess:
|
||||
"""After a successful PDF upload the user's ``pages_used`` must grow."""
|
||||
|
||||
async def test_pages_used_increases_after_pdf_upload(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=0, pages_limit=1000)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
used, _ = await page_limits.get()
|
||||
assert used > 0, "pages_used should have increased after successful processing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test B: Upload rejected when page limit is fully exhausted
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUploadRejectedWhenLimitExhausted:
|
||||
"""
|
||||
When ``pages_used == pages_limit`` (zero remaining) the document
|
||||
should reach ``failed`` status with a page-limit reason.
|
||||
"""
|
||||
|
||||
async def test_pdf_fails_when_no_pages_remaining(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=100, pages_limit=100)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "failed"
|
||||
reason = statuses[did]["status"].get("reason", "").lower()
|
||||
assert "page limit" in reason, (
|
||||
f"Expected 'page limit' in failure reason, got: {reason!r}"
|
||||
)
|
||||
|
||||
async def test_pages_used_unchanged_after_limit_rejection(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=50, pages_limit=50)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
|
||||
used, _ = await page_limits.get()
|
||||
assert used == 50, (
|
||||
f"pages_used should remain 50 after rejected upload, got {used}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test C: Page-limit notification is created on rejection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPageLimitNotification:
|
||||
"""A ``page_limit_exceeded`` notification must be created when upload
|
||||
is rejected due to the limit."""
|
||||
|
||||
async def test_page_limit_exceeded_notification_created(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=100, pages_limit=100)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
|
||||
notifications = await get_notifications(
|
||||
client,
|
||||
headers,
|
||||
type_filter="page_limit_exceeded",
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
assert len(notifications) >= 1, (
|
||||
"Expected at least one page_limit_exceeded notification"
|
||||
)
|
||||
|
||||
latest = notifications[0]
|
||||
assert (
|
||||
"page limit" in latest["title"].lower()
|
||||
or "page limit" in latest["message"].lower()
|
||||
), (
|
||||
f"Notification should mention page limit: title={latest['title']!r}, "
|
||||
f"message={latest['message']!r}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test D: Successful upload creates a completed document_processing notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentProcessingNotification:
|
||||
"""A ``document_processing`` notification with ``completed`` status must
|
||||
exist after a successful upload."""
|
||||
|
||||
async def test_processing_completed_notification_exists(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=0, pages_limit=1000)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
notifications = await get_notifications(
|
||||
client,
|
||||
headers,
|
||||
type_filter="document_processing",
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
completed = [
|
||||
n
|
||||
for n in notifications
|
||||
if n.get("metadata", {}).get("processing_stage") == "completed"
|
||||
]
|
||||
assert len(completed) >= 1, (
|
||||
"Expected at least one document_processing notification with 'completed' stage"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test E: pages_used unchanged when a document fails for non-limit reasons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPagesUnchangedOnProcessingFailure:
|
||||
"""If a document fails during ETL (e.g. empty/corrupt file) rather than
|
||||
a page-limit rejection, ``pages_used`` should remain unchanged."""
|
||||
|
||||
async def test_pages_used_stable_on_etl_failure(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=10, pages_limit=1000)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "empty.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
if doc_ids:
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "failed"
|
||||
|
||||
used, _ = await page_limits.get()
|
||||
assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test F: Second upload rejected after first consumes remaining quota
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecondUploadExceedsLimit:
|
||||
"""Upload one PDF successfully, consuming the quota, then verify a
|
||||
second upload is rejected."""
|
||||
|
||||
async def test_second_upload_rejected_after_quota_consumed(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
# Give just enough room for one ~1-page PDF
|
||||
await page_limits.set(pages_used=0, pages_limit=1)
|
||||
|
||||
resp1 = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp1.status_code == 200
|
||||
first_ids = resp1.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(first_ids)
|
||||
|
||||
statuses1 = await poll_document_status(
|
||||
client, headers, first_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in first_ids:
|
||||
assert statuses1[did]["status"]["state"] == "ready"
|
||||
|
||||
# Second upload — should fail because quota is now consumed
|
||||
resp2 = await upload_file(
|
||||
client,
|
||||
headers,
|
||||
"sample.pdf",
|
||||
search_space_id=search_space_id,
|
||||
filename_override="sample_copy.pdf",
|
||||
)
|
||||
assert resp2.status_code == 200
|
||||
second_ids = resp2.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(second_ids)
|
||||
|
||||
statuses2 = await poll_document_status(
|
||||
client, headers, second_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in second_ids:
|
||||
assert statuses2[did]["status"]["state"] == "failed"
|
||||
reason = statuses2[did]["status"].get("reason", "").lower()
|
||||
assert "page limit" in reason, (
|
||||
f"Expected 'page limit' in failure reason, got: {reason!r}"
|
||||
)
|
||||
surfsense_backend/tests/e2e/test_upload_limits.py (new file, 146 lines)
|
|
@ -0,0 +1,146 @@
|
|||
"""
|
||||
End-to-end tests for backend file upload limit enforcement.
|
||||
|
||||
These tests verify that the API rejects uploads that exceed:
|
||||
- Max files per upload (10)
|
||||
- Max per-file size (50 MB)
|
||||
- Max total upload size (200 MB)
|
||||
|
||||
The limits mirror the frontend's DocumentUploadTab.tsx constants and are
|
||||
enforced server-side to protect against direct API calls.
|
||||
|
||||
Prerequisites (must be running):
|
||||
- FastAPI backend
|
||||
- PostgreSQL + pgvector
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test A: File count limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileCountLimit:
|
||||
"""Uploading more than 10 files in a single request should be rejected."""
|
||||
|
||||
async def test_11_files_returns_413(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
files = [
|
||||
("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
|
||||
for i in range(11)
|
||||
]
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=files,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
assert "too many files" in resp.json()["detail"].lower()
|
||||
|
||||
async def test_10_files_accepted(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
files = [
|
||||
("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
|
||||
for i in range(10)
|
||||
]
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=files,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cleanup_doc_ids.extend(resp.json().get("document_ids", []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test B: Per-file size limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPerFileSizeLimit:
|
||||
"""A single file exceeding 50 MB should be rejected."""
|
||||
|
||||
async def test_oversized_file_returns_413(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
oversized = io.BytesIO(b"\x00" * (50 * 1024 * 1024 + 1))
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=[("files", ("big.pdf", oversized, "application/pdf"))],
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
assert "per-file limit" in resp.json()["detail"].lower()
|
||||
|
||||
async def test_file_at_limit_accepted(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
at_limit = io.BytesIO(b"\x00" * (50 * 1024 * 1024))
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=[("files", ("exact50mb.txt", at_limit, "text/plain"))],
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cleanup_doc_ids.extend(resp.json().get("document_ids", []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test C: Total upload size limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTotalSizeLimit:
|
||||
"""Multiple files whose combined size exceeds 200 MB should be rejected."""
|
||||
|
||||
async def test_total_size_over_200mb_returns_413(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
chunk_size = 45 * 1024 * 1024 # 45 MB each
|
||||
files = [
|
||||
(
|
||||
"files",
|
||||
(f"chunk_{i}.txt", io.BytesIO(b"\x00" * chunk_size), "text/plain"),
|
||||
)
|
||||
for i in range(5) # 5 x 45 MB = 225 MB > 200 MB
|
||||
]
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=files,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
assert "total upload size" in resp.json()["detail"].lower()
|
||||
surfsense_backend/tests/fixtures/empty.pdf (new vendored file, 0 lines)
surfsense_backend/tests/fixtures/sample.md (new vendored file, 51 lines)
|
|
@ -0,0 +1,51 @@
|
|||
# SurfSense Test Document

## Overview

This is a **sample markdown document** used for end-to-end testing of the manual
document upload pipeline. It includes various markdown formatting elements.

## Key Features

- Document upload and processing
- Automatic chunking of content
- Embedding generation for semantic search
- Real-time status tracking via ElectricSQL

## Technical Architecture

### Backend Stack

The SurfSense backend is built with:

1. **FastAPI** for the REST API
2. **PostgreSQL** with pgvector for vector storage
3. **Celery** with Redis for background task processing
4. **Docling/Unstructured** for document parsing (ETL)

### Processing Pipeline

Documents go through a multi-stage pipeline:

| Stage | Description |
|-------|-------------|
| Upload | File received via API endpoint |
| Parsing | Content extracted using ETL service |
| Chunking | Text split into semantic chunks |
| Embedding | Vector representations generated |
| Storage | Chunks stored with embeddings in pgvector |

## Code Example

```python
async def process_document(file_path: str) -> Document:
    content = extract_content(file_path)
    chunks = create_chunks(content)
    embeddings = generate_embeddings(chunks)
    return store_document(chunks, embeddings)
```

## Conclusion

This document serves as a test fixture to validate the complete document processing
pipeline from upload through to chunk creation and embedding storage.
BIN
surfsense_backend/tests/fixtures/sample.pdf
vendored
Normal file
Binary file not shown.
34
surfsense_backend/tests/fixtures/sample.txt
vendored
Normal file
@@ -0,0 +1,34 @@
SurfSense Document Upload Test

This is a sample text document used for end-to-end testing of the manual document
upload pipeline in SurfSense. The document contains multiple paragraphs to ensure
that the chunking system has enough content to work with.

Artificial Intelligence and Machine Learning

Artificial intelligence (AI) is a broad field of computer science concerned with
building smart machines capable of performing tasks that typically require human
intelligence. Machine learning is a subset of AI that enables systems to learn and
improve from experience without being explicitly programmed.

Natural Language Processing

Natural language processing (NLP) is a subfield of linguistics, computer science,
and artificial intelligence concerned with the interactions between computers and
human language. Key applications include machine translation, sentiment analysis,
text summarization, and question answering systems.

Vector Databases and Semantic Search

Vector databases store data as high-dimensional vectors, enabling efficient
similarity search operations. When combined with embedding models, they power
semantic search systems that understand the meaning behind queries rather than
relying on exact keyword matches. This technology is fundamental to modern
retrieval-augmented generation (RAG) systems.

Document Processing Pipelines

Modern document processing pipelines involve several stages: extraction, transformation,
chunking, embedding generation, and storage. Each stage plays a critical role in
converting raw documents into searchable, structured knowledge that can be retrieved
and used by AI systems for accurate information retrieval and generation.
0
surfsense_backend/tests/integration/__init__.py
Normal file
172
surfsense_backend/tests/integration/conftest.py
Normal file
@@ -0,0 +1,172 @@
import os
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.db import (
|
||||
Base,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
SearchSpace,
|
||||
User,
|
||||
)
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
|
||||
_EMBEDDING_DIM = 1024 # must match the Vector() dimension used in DB column creation
|
||||
|
||||
_DEFAULT_TEST_DB = (
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
|
||||
)
|
||||
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
# Required for asyncpg + savepoints: disables prepared statement cache
|
||||
# to prevent "another operation is in progress" errors during savepoint rollbacks.
|
||||
connect_args={"prepared_statement_cache_size": 0},
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
|
||||
# DROP SCHEMA CASCADE handles this without needing topological sort.
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("DROP SCHEMA public CASCADE"))
|
||||
await conn.execute(text("CREATE SCHEMA public"))
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(async_engine) -> AsyncSession:
|
||||
# Bind the session to a connection that holds an outer transaction.
|
||||
# join_transaction_mode="create_savepoint" makes session.commit() release
|
||||
# a SAVEPOINT instead of committing the outer transaction, so the final
|
||||
# transaction.rollback() undoes everything — including commits made by the
|
||||
# service under test — leaving the DB clean for the next test.
|
||||
async with async_engine.connect() as conn:
|
||||
transaction = await conn.begin()
|
||||
async with AsyncSession(
|
||||
bind=conn,
|
||||
expire_on_commit=False,
|
||||
join_transaction_mode="create_savepoint",
|
||||
) as session:
|
||||
yield session
|
||||
await transaction.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_user(db_session: AsyncSession) -> User:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@surfsense.net",
|
||||
hashed_password="hashed",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_connector(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace"
|
||||
) -> SearchSourceConnector:
|
||||
connector = SearchSourceConnector(
|
||||
name="Test Connector",
|
||||
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||
config={},
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=db_user.id,
|
||||
)
|
||||
db_session.add(connector)
|
||||
await db_session.flush()
|
||||
return connector
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
|
||||
space = SearchSpace(
|
||||
name="Test Space",
|
||||
user_id=db_user.id,
|
||||
)
|
||||
db_session.add(space)
|
||||
await db_session.flush()
|
||||
return space
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_summarize(monkeypatch) -> AsyncMock:
|
||||
mock = AsyncMock(return_value="Mocked summary.")
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_summarize_raises(monkeypatch) -> AsyncMock:
|
||||
mock = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_embed_text(monkeypatch) -> MagicMock:
|
||||
mock = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.embed_text",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_chunk_text(monkeypatch) -> MagicMock:
|
||||
mock = MagicMock(return_value=["Test chunk content."])
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_connector_document(db_connector, db_user):
|
||||
"""Integration-scoped override: uses real DB connector and user IDs."""
|
||||
|
||||
def _make(**overrides):
|
||||
defaults = {
|
||||
"title": "Test Document",
|
||||
"source_markdown": "## Heading\n\nSome content.",
|
||||
"unique_id": "test-id-001",
|
||||
"document_type": DocumentType.CLICKUP_CONNECTOR,
|
||||
"search_space_id": db_connector.search_space_id,
|
||||
"connector_id": db_connector.id,
|
||||
"created_by_id": str(db_user.id),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ConnectorDocument(**defaults)
|
||||
|
||||
return _make
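The `db_session` fixture above binds the session to an outer transaction with `join_transaction_mode="create_savepoint"`, so `commit()` calls made by the code under test only release savepoints and the final rollback leaves the database clean. A minimal illustration of that isolation (this test is hypothetical, not part of the commit):

```python
import pytest
from sqlalchemy import select

from app.db import SearchSpace

pytestmark = pytest.mark.integration


async def test_commits_are_rolled_back_between_tests(db_session, db_user):
    """Sketch: a commit inside a test only releases a SAVEPOINT, so nothing leaks."""
    db_session.add(SearchSpace(name="Scratch Space", user_id=db_user.id))
    await db_session.commit()  # releases a SAVEPOINT, not the outer transaction

    result = await db_session.execute(
        select(SearchSpace).filter(SearchSpace.name == "Scratch Space")
    )
    assert result.scalars().first() is not None
    # After the fixture's transaction.rollback(), this row is gone for the next test.
```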
@@ -0,0 +1,99 @@
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
|
||||
"""Document status is READY after successful indexing."""
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(document.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
|
||||
"""Document content is set to the LLM-generated summary."""
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
assert document.content == "Mocked summary."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker):
|
||||
"""Chunks derived from the source markdown are persisted in the DB."""
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
chunks_result = await db_session.execute(
|
||||
select(Chunk).filter(Chunk.document_id == document.id)
|
||||
)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == "Test chunk content."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker):
|
||||
"""RuntimeError is raised when the indexing step fails so the caller can fire a failure notification."""
|
||||
with pytest.raises(RuntimeError):
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
@@ -0,0 +1,338 @@
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_sets_status_ready(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Document status is READY after successful indexing."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_content_is_summary_when_should_summarize_true(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Document content is set to the LLM-generated summary when should_summarize=True."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.content == "Mocked summary."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_content_is_source_markdown_when_should_summarize_false(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
):
|
||||
"""Document content is set to source_markdown verbatim when should_summarize=False."""
|
||||
connector_doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
should_summarize=False,
|
||||
source_markdown="## Raw content",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=None)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.content == "## Raw content"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_chunks_written_to_db(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Chunks derived from source_markdown are persisted in the DB."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Chunk).filter(Chunk.document_id == document_id)
|
||||
)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == "Test chunk content."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_embedding_written_to_db(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Document embedding vector is persisted in the DB after indexing."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.embedding is not None
|
||||
assert len(reloaded.embedding) == 1024
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_updated_at_advances_after_indexing(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""updated_at timestamp is later after indexing than it was at prepare time."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_pending = result.scalars().first().updated_at
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_ready = result.scalars().first().updated_at
|
||||
|
||||
assert updated_at_ready > updated_at_pending
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_no_llm_falls_back_to_source_markdown(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
):
|
||||
"""When llm=None and no fallback_summary, content falls back to source_markdown."""
|
||||
connector_doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
should_summarize=True,
|
||||
source_markdown="## Fallback content",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=None)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
|
||||
assert reloaded.content == "## Fallback content"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_fallback_summary_used_when_llm_unavailable(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
):
|
||||
"""fallback_summary is used as content when llm=None and should_summarize=True."""
|
||||
connector_doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
should_summarize=True,
|
||||
source_markdown="## Full raw content",
|
||||
fallback_summary="Short pre-built summary.",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document_id = prepared[0].id
|
||||
|
||||
await service.index(prepared[0], connector_doc, llm=None)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
|
||||
assert reloaded.content == "Short pre-built summary."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_reindex_replaces_old_chunks(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Re-indexing a document replaces its old chunks rather than appending."""
|
||||
connector_doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
source_markdown="## v1",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
updated_doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
source_markdown="## v2",
|
||||
)
|
||||
re_prepared = await service.prepare_for_indexing([updated_doc])
|
||||
await service.index(re_prepared[0], updated_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Chunk).filter(Chunk.document_id == document_id)
|
||||
)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
assert len(chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_llm_error_sets_status_failed(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Document status is FAILED when the LLM raises during indexing."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_llm_error_leaves_no_partial_data(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""A failed indexing attempt leaves no partial embedding or chunks in the DB."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.embedding is None
|
||||
assert reloaded.content == "Pending..."
|
||||
|
||||
chunks_result = await db_session.execute(
|
||||
select(Chunk).filter(Chunk.document_id == document_id)
|
||||
)
|
||||
assert chunks_result.scalars().all() == []
|
||||
@@ -0,0 +1,459 @@
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash as real_compute_content_hash,
|
||||
)
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_new_document_is_persisted_with_pending_status(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""A new document is created in the DB with PENDING status and correct markdown."""
|
||||
doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing([doc])
|
||||
|
||||
assert len(results) == 1
|
||||
document_id = results[0].id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded is not None
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
|
||||
assert reloaded.source_markdown == doc.source_markdown
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_unchanged_ready_document_is_skipped(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""A READY document with unchanged content is not returned for re-indexing."""
|
||||
doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
# Index fully so the document reaches ready state
|
||||
prepared = await service.prepare_for_indexing([doc])
|
||||
await service.index(prepared[0], doc, llm=mocker.Mock())
|
||||
|
||||
# Same content on the next run — a ready document must be skipped
|
||||
results = await service.prepare_for_indexing([doc])
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_title_only_change_updates_title_in_db(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""A title-only change updates the DB title without re-queuing the document."""
|
||||
original = make_connector_document(
|
||||
search_space_id=db_search_space.id, title="Original Title"
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([original])
|
||||
document_id = prepared[0].id
|
||||
await service.index(prepared[0], original, llm=mocker.Mock())
|
||||
|
||||
renamed = make_connector_document(
|
||||
search_space_id=db_search_space.id, title="Updated Title"
|
||||
)
|
||||
results = await service.prepare_for_indexing([renamed])
|
||||
|
||||
assert results == []
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.title == "Updated Title"
|
||||
|
||||
|
||||
async def test_changed_content_is_returned_for_reprocessing(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""A document with changed content is returned for re-indexing with updated markdown."""
|
||||
original = make_connector_document(
|
||||
search_space_id=db_search_space.id, source_markdown="## v1"
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
original_id = first[0].id
|
||||
|
||||
updated = make_connector_document(
|
||||
search_space_id=db_search_space.id, source_markdown="## v2"
|
||||
)
|
||||
results = await service.prepare_for_indexing([updated])
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == original_id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == original_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.source_markdown == "## v2"
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
|
||||
|
||||
|
||||
async def test_all_documents_in_batch_are_persisted(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""All documents in a batch are persisted and returned."""
|
||||
docs = [
|
||||
make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="id-1",
|
||||
title="Doc 1",
|
||||
source_markdown="## Content 1",
|
||||
),
|
||||
make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="id-2",
|
||||
title="Doc 2",
|
||||
source_markdown="## Content 2",
|
||||
),
|
||||
make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="id-3",
|
||||
title="Doc 3",
|
||||
source_markdown="## Content 3",
|
||||
),
|
||||
]
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing(docs)
|
||||
|
||||
assert len(results) == 3
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
|
||||
assert len(rows) == 3
|
||||
|
||||
|
||||
async def test_duplicate_in_batch_is_persisted_once(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""The same document passed twice in a batch is only persisted once."""
|
||||
doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing([doc, doc])
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
|
||||
assert len(rows) == 1
|
||||
|
||||
|
||||
async def test_created_by_id_is_persisted(
|
||||
db_session, db_user, db_search_space, make_connector_document
|
||||
):
|
||||
"""created_by_id from the connector document is persisted on the DB row."""
|
||||
doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
created_by_id=str(db_user.id),
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing([doc])
|
||||
document_id = results[0].id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert str(reloaded.created_by_id) == str(db_user.id)
|
||||
|
||||
|
||||
async def test_metadata_is_updated_when_content_changes(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""document_metadata is overwritten with the latest metadata when content changes."""
|
||||
original = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
source_markdown="## v1",
|
||||
metadata={"status": "in_progress"},
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
document_id = first[0].id
|
||||
|
||||
updated = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
source_markdown="## v2",
|
||||
metadata={"status": "done"},
|
||||
)
|
||||
await service.prepare_for_indexing([updated])
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.document_metadata == {"status": "done"}
|
||||
|
||||
|
||||
async def test_updated_at_advances_when_title_only_changes(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""updated_at advances even when only the title changes."""
|
||||
original = make_connector_document(
|
||||
search_space_id=db_search_space.id, title="Old Title"
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
document_id = first[0].id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_v1 = result.scalars().first().updated_at
|
||||
|
||||
renamed = make_connector_document(
|
||||
search_space_id=db_search_space.id, title="New Title"
|
||||
)
|
||||
await service.prepare_for_indexing([renamed])
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_v2 = result.scalars().first().updated_at
|
||||
|
||||
assert updated_at_v2 > updated_at_v1
|
||||
|
||||
|
||||
async def test_updated_at_advances_when_content_changes(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""updated_at advances when document content changes."""
|
||||
original = make_connector_document(
|
||||
search_space_id=db_search_space.id, source_markdown="## v1"
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
document_id = first[0].id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_v1 = result.scalars().first().updated_at
|
||||
|
||||
updated = make_connector_document(
|
||||
search_space_id=db_search_space.id, source_markdown="## v2"
|
||||
)
|
||||
await service.prepare_for_indexing([updated])
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_v2 = result.scalars().first().updated_at
|
||||
|
||||
assert updated_at_v2 > updated_at_v1
|
||||
|
||||
|
||||
async def test_same_content_from_different_source_skipped_in_single_batch(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""Two documents with identical content in the same batch result in only one being persisted."""
|
||||
first = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="source-a",
|
||||
source_markdown="## Shared content",
|
||||
)
|
||||
second = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="source-b",
|
||||
source_markdown="## Shared content",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing([first, second])
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
assert len(result.scalars().all()) == 1
|
||||
|
||||
|
||||
async def test_same_content_from_different_source_is_skipped(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""A document with content identical to an already-indexed document is skipped."""
|
||||
first = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="source-a",
|
||||
source_markdown="## Shared content",
|
||||
)
|
||||
second = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="source-b",
|
||||
source_markdown="## Shared content",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
await service.prepare_for_indexing([first])
|
||||
results = await service.prepare_for_indexing([second])
|
||||
|
||||
assert results == []
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
assert len(result.scalars().all()) == 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_failed_document_with_unchanged_content_is_requeued(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""A FAILED document with unchanged content is re-queued as PENDING on the next run."""
|
||||
doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
# First run: document is created and indexing crashes → status = failed
|
||||
prepared = await service.prepare_for_indexing([doc])
|
||||
document_id = prepared[0].id
|
||||
await service.index(prepared[0], doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
assert DocumentStatus.is_state(
|
||||
result.scalars().first().status, DocumentStatus.FAILED
|
||||
)
|
||||
|
||||
# Next run: same content, pipeline must re-queue the failed document
|
||||
results = await service.prepare_for_indexing([doc])
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == document_id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
assert DocumentStatus.is_state(
|
||||
result.scalars().first().status, DocumentStatus.PENDING
|
||||
)
|
||||
|
||||
|
||||
async def test_title_and_content_change_updates_both_and_returns_document(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
"""When both title and content change, both are updated and the document is returned for re-indexing."""
|
||||
original = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
title="Original Title",
|
||||
source_markdown="## v1",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
original_id = first[0].id
|
||||
|
||||
updated = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
title="Updated Title",
|
||||
source_markdown="## v2",
|
||||
)
|
||||
results = await service.prepare_for_indexing([updated])
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == original_id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == original_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.title == "Updated Title"
|
||||
assert reloaded.source_markdown == "## v2"
|
||||
|
||||
|
||||
async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_persisted(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
monkeypatch,
|
||||
):
|
||||
"""
|
||||
A per-document error during prepare_for_indexing must be isolated.
|
||||
The two valid documents around the failing one must still be persisted.
|
||||
"""
|
||||
docs = [
|
||||
make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="good-1",
|
||||
source_markdown="## Good doc 1",
|
||||
),
|
||||
make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="will-fail",
|
||||
source_markdown="## Bad doc",
|
||||
),
|
||||
make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
unique_id="good-2",
|
||||
source_markdown="## Good doc 2",
|
||||
),
|
||||
]
|
||||
|
||||
def compute_content_hash_with_error(doc):
|
||||
if doc.unique_id == "will-fail":
|
||||
raise RuntimeError("Simulated per-document failure")
|
||||
return real_compute_content_hash(doc)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.compute_content_hash",
|
||||
compute_content_hash_with_error,
|
||||
)
|
||||
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
results = await service.prepare_for_indexing(docs)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
assert len(result.scalars().all()) == 2
|
||||
0
surfsense_backend/tests/unit/__init__.py
Normal file
0
surfsense_backend/tests/unit/adapters/__init__.py
Normal file
38
surfsense_backend/tests/unit/indexing_pipeline/conftest.py
Normal file
@@ -0,0 +1,38 @@
from unittest.mock import AsyncMock, MagicMock

import pytest


@pytest.fixture
def patched_summarizer_chain(monkeypatch):
    chain = MagicMock()
    chain.ainvoke = AsyncMock(return_value=MagicMock(content="The summary."))

    template = MagicMock()
    template.__or__ = MagicMock(return_value=chain)

    monkeypatch.setattr(
        "app.indexing_pipeline.document_summarizer.SUMMARY_PROMPT_TEMPLATE",
        template,
    )
    return chain


@pytest.fixture
def patched_chunker_instance(monkeypatch):
    mock = MagicMock()
    mock.chunk.return_value = [MagicMock(text="prose chunk")]
    monkeypatch.setattr(
        "app.indexing_pipeline.document_chunker.config.chunker_instance", mock
    )
    return mock


@pytest.fixture
def patched_code_chunker_instance(monkeypatch):
    mock = MagicMock()
    mock.chunk.return_value = [MagicMock(text="code chunk")]
    monkeypatch.setattr(
        "app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock
    )
    return mock
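The `patched_summarizer_chain` fixture works by overriding `__or__` on the prompt template, which assumes the production summarizer builds its chain with the LangChain-style pipe operator (`SUMMARY_PROMPT_TEMPLATE | llm`). A self-contained sketch of why that patch hands the mocked chain to whatever gets piped onto the template:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

# Stand-ins for the real prompt template and LLM (hypothetical shapes).
chain = MagicMock()
chain.ainvoke = AsyncMock(return_value=MagicMock(content="The summary."))

template = MagicMock()
template.__or__ = MagicMock(return_value=chain)

# Anything piped onto the template now yields the mocked chain, which is what
# lets summarize_document run without a real prompt template or LLM.
piped = template | MagicMock(model="gpt-4")
assert piped is chain
print(asyncio.run(piped.ainvoke({"document": "# Content"})).content)  # -> "The summary."
```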
@@ -0,0 +1,112 @@
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.db import DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
|
||||
|
||||
def test_valid_document_created_with_required_fields():
|
||||
"""All optional fields default correctly when only required fields are supplied."""
|
||||
doc = ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="## Task\n\nSome content.",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
connector_id=42,
|
||||
created_by_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
assert doc.should_summarize is True
|
||||
assert doc.should_use_code_chunker is False
|
||||
assert doc.metadata == {}
|
||||
assert doc.connector_id == 42
|
||||
assert doc.created_by_id == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_omitting_created_by_id_raises():
|
||||
"""Omitting created_by_id raises a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="## Content",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
connector_id=42,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_source_markdown_raises():
|
||||
"""Empty source_markdown raises a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
)
|
||||
|
||||
|
||||
def test_whitespace_only_source_markdown_raises():
|
||||
"""Whitespace-only source_markdown raises a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown=" \n\t ",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_title_raises():
|
||||
"""Empty title raises a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="",
|
||||
source_markdown="## Content",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_created_by_id_raises():
|
||||
"""Empty created_by_id raises a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="## Content",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
connector_id=42,
|
||||
created_by_id="",
|
||||
)
|
||||
|
||||
|
||||
def test_zero_search_space_id_raises():
|
||||
"""search_space_id of zero raises a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="## Content",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=0,
|
||||
connector_id=42,
|
||||
created_by_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
|
||||
|
||||
def test_empty_unique_id_raises():
|
||||
"""Empty unique_id raises a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="## Content",
|
||||
unique_id="",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
)
|
||||
@@ -0,0 +1,21 @@
import pytest

from app.indexing_pipeline.document_chunker import chunk_text

pytestmark = pytest.mark.unit


@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance")
def test_uses_code_chunker_when_flag_is_true():
    """Code chunker is selected when use_code_chunker=True."""
    result = chunk_text("def foo(): pass", use_code_chunker=True)

    assert result == ["code chunk"]


@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance")
def test_uses_default_chunker_when_flag_is_false():
    """Default prose chunker is selected when use_code_chunker=False."""
    result = chunk_text("Some prose text.", use_code_chunker=False)

    assert result == ["prose chunk"]
@@ -0,0 +1,63 @@
import pytest
|
||||
|
||||
from app.db import DocumentType
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_different_unique_id_produces_different_hash(make_connector_document):
|
||||
"""Two documents with different unique_ids produce different identifier hashes."""
|
||||
doc_a = make_connector_document(unique_id="id-001")
|
||||
doc_b = make_connector_document(unique_id="id-002")
|
||||
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
|
||||
doc_b
|
||||
)
|
||||
|
||||
|
||||
def test_different_search_space_produces_different_identifier_hash(
|
||||
make_connector_document,
|
||||
):
|
||||
"""Same document in different search spaces produces different identifier hashes."""
|
||||
doc_a = make_connector_document(search_space_id=1)
|
||||
doc_b = make_connector_document(search_space_id=2)
|
||||
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
|
||||
doc_b
|
||||
)
|
||||
|
||||
|
||||
def test_different_document_type_produces_different_identifier_hash(
|
||||
make_connector_document,
|
||||
):
|
||||
"""Same unique_id with different document types produces different identifier hashes."""
|
||||
doc_a = make_connector_document(document_type=DocumentType.CLICKUP_CONNECTOR)
|
||||
doc_b = make_connector_document(document_type=DocumentType.NOTION_CONNECTOR)
|
||||
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
|
||||
doc_b
|
||||
)
|
||||
|
||||
|
||||
def test_same_content_same_space_produces_same_content_hash(make_connector_document):
|
||||
"""Identical content in the same search space always produces the same content hash."""
|
||||
doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
|
||||
doc_b = make_connector_document(source_markdown="Hello world", search_space_id=1)
|
||||
assert compute_content_hash(doc_a) == compute_content_hash(doc_b)
|
||||
|
||||
|
||||
def test_same_content_different_space_produces_different_content_hash(
|
||||
make_connector_document,
|
||||
):
|
||||
"""Identical content in different search spaces produces different content hashes."""
|
||||
doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
|
||||
doc_b = make_connector_document(source_markdown="Hello world", search_space_id=2)
|
||||
assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
|
||||
|
||||
|
||||
def test_different_content_produces_different_content_hash(make_connector_document):
|
||||
"""Different source markdown produces different content hashes."""
|
||||
doc_a = make_connector_document(source_markdown="Original content")
|
||||
doc_b = make_connector_document(source_markdown="Updated content")
|
||||
assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
|
||||
@@ -0,0 +1,41 @@
from unittest.mock import MagicMock

import pytest

from app.indexing_pipeline.document_summarizer import summarize_document

pytestmark = pytest.mark.unit


@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_without_metadata_returns_raw_summary():
    """Summarizer returns the LLM output directly when no metadata is provided."""
    result = await summarize_document("# Content", llm=MagicMock(model="gpt-4"))

    assert result == "The summary."


@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_with_metadata_includes_metadata_values_in_output():
    """Non-empty metadata values are prepended to the summary output."""
    result = await summarize_document(
        "# Content",
        llm=MagicMock(model="gpt-4"),
        metadata={"author": "Alice", "source": "Notion"},
    )

    assert "Alice" in result
    assert "Notion" in result


@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_with_metadata_omits_empty_fields_from_output():
    """Empty metadata fields are omitted from the summary output."""
    result = await summarize_document(
        "# Content",
        llm=MagicMock(model="gpt-4"),
        metadata={"author": "Alice", "description": ""},
    )

    assert "Alice" in result
    assert "description" not in result.lower()
0
surfsense_backend/tests/utils/__init__.py
Normal file
224
surfsense_backend/tests/utils/helpers.py
Normal file
@@ -0,0 +1,224 @@
"""Shared test helpers for authentication, polling, and cleanup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
FIXTURES_DIR = Path(__file__).resolve().parent.parent / "fixtures"
|
||||
|
||||
BACKEND_URL = os.environ.get("TEST_BACKEND_URL", "http://localhost:8000")
|
||||
TEST_EMAIL = os.environ.get("TEST_USER_EMAIL", "testuser@surfsense.com")
|
||||
TEST_PASSWORD = os.environ.get("TEST_USER_PASSWORD", "testpassword123")
|
||||
|
||||
|
||||
async def get_auth_token(client: httpx.AsyncClient) -> str:
|
||||
"""Log in and return a Bearer JWT token, registering the user first if needed."""
|
||||
response = await client.post(
|
||||
"/auth/jwt/login",
|
||||
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    if response.status_code == 200:
        return response.json()["access_token"]

    reg_response = await client.post(
        "/auth/register",
        json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
    )
    assert reg_response.status_code == 201, (
        f"Registration failed ({reg_response.status_code}): {reg_response.text}"
    )

    response = await client.post(
        "/auth/jwt/login",
        data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    assert response.status_code == 200, (
        f"Login after registration failed ({response.status_code}): {response.text}"
    )
    return response.json()["access_token"]


async def get_search_space_id(client: httpx.AsyncClient, token: str) -> int:
    """Fetch the first search space owned by the test user."""
    resp = await client.get(
        "/api/v1/searchspaces",
        headers=auth_headers(token),
    )
    assert resp.status_code == 200, (
        f"Failed to list search spaces ({resp.status_code}): {resp.text}"
    )
    spaces = resp.json()
    assert len(spaces) > 0, "No search spaces found for test user"
    return spaces[0]["id"]


def auth_headers(token: str) -> dict[str, str]:
    """Return Authorization header dict for a Bearer token."""
    return {"Authorization": f"Bearer {token}"}


async def upload_file(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    fixture_name: str,
    *,
    search_space_id: int,
    filename_override: str | None = None,
) -> httpx.Response:
    """Upload a single fixture file and return the raw response."""
    file_path = FIXTURES_DIR / fixture_name
    upload_name = filename_override or fixture_name
    with open(file_path, "rb") as f:
        return await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files={"files": (upload_name, f)},
            data={"search_space_id": str(search_space_id)},
        )


async def upload_multiple_files(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    fixture_names: list[str],
    *,
    search_space_id: int,
) -> httpx.Response:
    """Upload multiple fixture files in a single request."""
    files = []
    open_handles = []
    try:
        for name in fixture_names:
            fh = open(FIXTURES_DIR / name, "rb")  # noqa: SIM115
            open_handles.append(fh)
            files.append(("files", (name, fh)))

        return await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=files,
            data={"search_space_id": str(search_space_id)},
        )
    finally:
        for fh in open_handles:
            fh.close()


async def poll_document_status(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    document_ids: list[int],
    *,
    search_space_id: int,
    timeout: float = 180.0,
    interval: float = 3.0,
) -> dict[int, dict]:
    """
    Poll ``GET /api/v1/documents/status`` until every document reaches a
    terminal state (``ready`` or ``failed``) or *timeout* seconds elapse.

    Returns a mapping of ``{document_id: status_item_dict}``.

    Retries on transient transport errors until timeout.
    """
    ids_param = ",".join(str(d) for d in document_ids)
    terminal_states = {"ready", "failed"}
    elapsed = 0.0
    items: dict[int, dict] = {}
    last_transport_error: Exception | None = None

    while elapsed < timeout:
        try:
            resp = await client.get(
                "/api/v1/documents/status",
                headers=headers,
                params={
                    "search_space_id": search_space_id,
                    "document_ids": ids_param,
                },
            )
        except (httpx.ReadError, httpx.ConnectError, httpx.TimeoutException) as exc:
            last_transport_error = exc
            await asyncio.sleep(interval)
            elapsed += interval
            continue

        assert resp.status_code == 200, (
            f"Status poll failed ({resp.status_code}): {resp.text}"
        )

        items = {item["id"]: item for item in resp.json()["items"]}
        if all(
            items.get(did, {}).get("status", {}).get("state") in terminal_states
            for did in document_ids
        ):
            return items

        await asyncio.sleep(interval)
        elapsed += interval

    raise TimeoutError(
        f"Documents {document_ids} did not reach terminal state within {timeout}s. "
        f"Last status: {items}. "
        f"Last transport error: {last_transport_error!r}"
    )


async def get_document(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    document_id: int,
) -> dict:
    """Fetch a single document by ID."""
    resp = await client.get(
        f"/api/v1/documents/{document_id}",
        headers=headers,
    )
    assert resp.status_code == 200, (
        f"GET document {document_id} failed ({resp.status_code}): {resp.text}"
    )
    return resp.json()


async def delete_document(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    document_id: int,
) -> httpx.Response:
    """Delete a document by ID, returning the raw response."""
    return await client.delete(
        f"/api/v1/documents/{document_id}",
        headers=headers,
    )


async def get_notifications(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    *,
    type_filter: str | None = None,
    search_space_id: int | None = None,
    limit: int = 50,
) -> list[dict]:
    """Fetch notifications for the authenticated user, optionally filtered by type."""
    params: dict[str, str | int] = {"limit": limit}
    if type_filter:
        params["type"] = type_filter
    if search_space_id is not None:
        params["search_space_id"] = search_space_id

    resp = await client.get(
        "/api/v1/notifications",
        headers=headers,
        params=params,
    )
    assert resp.status_code == 200, (
        f"GET notifications failed ({resp.status_code}): {resp.text}"
    )
    return resp.json()["items"]
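Taken together, these helpers compose into a full upload-and-ingest flow. The sketch below is illustrative only and is not part of this commit: it assumes pytest with pytest-asyncio, a backend reachable at BASE_URL, a sample.pdf fixture under FIXTURES_DIR, a JSON list of created documents in the upload response, and a login helper here called get_auth_token standing in for the token helper defined above.

import httpx
import pytest

BASE_URL = "http://localhost:8000"  # assumption: local backend under test


@pytest.mark.asyncio
async def test_upload_roundtrip():
    # Illustrative composition of the helpers above; names flagged in the
    # lead-in are hypothetical.
    async with httpx.AsyncClient(base_url=BASE_URL, timeout=30.0) as client:
        token = await get_auth_token(client)  # hypothetical name for the token helper above
        headers = auth_headers(token)
        space_id = await get_search_space_id(client, token)

        # Upload one fixture, then wait for ingestion to reach a terminal state.
        resp = await upload_file(client, headers, "sample.pdf", search_space_id=space_id)
        assert resp.status_code in (200, 201), resp.text
        # Assumes the upload response is a JSON list of created documents with "id" fields.
        doc_ids = [item["id"] for item in resp.json()]

        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=space_id
        )
        assert all(s["status"]["state"] == "ready" for s in statuses.values())

        # Clean up so the test stays repeatable.
        for doc_id in doc_ids:
            delete_resp = await delete_document(client, headers, doc_id)
            assert delete_resp.status_code in (200, 204), delete_resp.text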
surfsense_backend/uv.lock (generated, 8889 lines) — file diff suppressed because it is too large
|
|
@@ -5,13 +5,11 @@ import { FeaturesBentoGrid } from "@/components/homepage/features-bento-grid";
import { FeaturesCards } from "@/components/homepage/features-card";
import { HeroSection } from "@/components/homepage/hero-section";
import ExternalIntegrations from "@/components/homepage/integrations";
import { UseCasesGrid } from "@/components/homepage/use-cases-grid";

export default function HomePage() {
	return (
		<main className="min-h-screen bg-gradient-to-b from-gray-50 to-gray-100 text-gray-900 dark:from-black dark:to-gray-900 dark:text-white">
			<HeroSection />
			<UseCasesGrid />
			<FeaturesCards />
			<FeaturesBentoGrid />
			<ExternalIntegrations />

@@ -67,7 +67,7 @@ export function DocumentTypeChip({ type, className }: { type: string; className?
	const chip = (
		<span
-			className={`inline-flex items-center gap-1.5 rounded-full bg-accent/80 px-2.5 py-1 text-xs font-medium text-accent-foreground/70 shadow-sm max-w-full overflow-hidden ${className ?? ""}`}
+			className={`inline-flex items-center gap-1.5 rounded-full bg-accent/80 px-2.5 py-1 text-xs font-medium text-accent-foreground shadow-sm max-w-full overflow-hidden ${className ?? ""}`}
		>
			<span className="flex-shrink-0">{icon}</span>
			<span ref={textRef} className="truncate min-w-0">
|
|
|||
|
|
@ -3,13 +3,13 @@
|
|||
import { useSetAtom } from "jotai";
|
||||
import {
|
||||
CircleAlert,
|
||||
CircleX,
|
||||
FilePlus2,
|
||||
FileType,
|
||||
ListFilter,
|
||||
Search,
|
||||
SlidersHorizontal,
|
||||
Trash,
|
||||
Upload,
|
||||
X,
|
||||
} from "lucide-react";
|
||||
import { motion } from "motion/react";
|
||||
import { useTranslations } from "next-intl";
|
||||
|
|
@ -81,7 +81,7 @@ export function DocumentsFilters({
|
|||
|
||||
return (
|
||||
<motion.div
|
||||
className="flex flex-col gap-4"
|
||||
className="flex flex-col gap-4 select-none"
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ type: "spring", stiffness: 300, damping: 30, delay: 0.1 }}
|
||||
|
|
@ -96,7 +96,7 @@ export function DocumentsFilters({
|
|||
size="sm"
|
||||
className="h-9 gap-2 bg-white text-gray-700 border-white hover:bg-gray-50 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100"
|
||||
>
|
||||
<FilePlus2 size={16} />
|
||||
<Upload size={16} />
|
||||
<span>Upload documents</span>
|
||||
</Button>
|
||||
<Button
|
||||
|
|
@ -126,7 +126,7 @@ export function DocumentsFilters({
|
|||
<Input
|
||||
id={`${id}-input`}
|
||||
ref={inputRef}
|
||||
className="peer h-9 w-full pl-9 pr-9 text-sm bg-background border-border/60 focus-visible:ring-1 focus-visible:ring-ring/30"
|
||||
className="peer h-9 w-full pl-9 pr-9 text-sm bg-background border-border/60 focus-visible:ring-1 focus-visible:ring-ring/30 select-none focus:select-text"
|
||||
value={searchValue}
|
||||
onChange={(e) => onSearch(e.target.value)}
|
||||
placeholder="Filter by title"
|
||||
|
|
@ -135,7 +135,7 @@ export function DocumentsFilters({
|
|||
/>
|
||||
{Boolean(searchValue) && (
|
||||
<motion.button
|
||||
className="absolute inset-y-0 right-0 flex h-full w-9 items-center justify-center rounded-r-md text-muted-foreground/60 hover:text-foreground transition-colors"
|
||||
className="absolute inset-y-0 right-0 flex h-full w-9 items-center justify-center rounded-r-md text-muted-foreground hover:text-foreground transition-colors"
|
||||
aria-label="Clear filter"
|
||||
onClick={() => {
|
||||
onSearch("");
|
||||
|
|
@ -147,7 +147,7 @@ export function DocumentsFilters({
|
|||
whileHover={{ scale: 1.1 }}
|
||||
whileTap={{ scale: 0.9 }}
|
||||
>
|
||||
<CircleX size={14} strokeWidth={2} aria-hidden="true" />
|
||||
<X size={14} strokeWidth={2} aria-hidden="true" />
|
||||
</motion.button>
|
||||
)}
|
||||
</motion.div>
|
||||
|
|
|
|||
|
|
@ -336,7 +336,7 @@ export function DocumentsTableShell({
|
|||
|
||||
return (
|
||||
<motion.div
|
||||
className="rounded-lg border border-border/40 bg-background overflow-hidden"
|
||||
className="rounded-lg border border-border/40 bg-background overflow-hidden select-none"
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ type: "spring", stiffness: 300, damping: 30, delay: 0.2 }}
|
||||
|
|
@ -453,7 +453,7 @@ export function DocumentsTableShell({
|
|||
) : error ? (
|
||||
<div className="flex h-[50vh] w-full items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<AlertCircle className="h-8 w-8 text-destructive/60" />
|
||||
<AlertCircle className="h-8 w-8 text-destructive" />
|
||||
<p className="text-sm text-destructive">{t("error_loading")}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -482,7 +482,7 @@ export function DocumentsTableShell({
|
|||
</div>
|
||||
) : (
|
||||
<>
|
||||
{/* Desktop Table View - Notion Style */}
|
||||
{/* Desktop Table View */}
|
||||
<div className="hidden md:flex md:flex-col">
|
||||
{/* Fixed Header */}
|
||||
<Table className="table-fixed w-full">
|
||||
|
|
@ -629,7 +629,24 @@ export function DocumentsTableShell({
|
|||
)}
|
||||
{columnVisibility.created_by && (
|
||||
<TableCell className="w-36 py-2.5 text-sm text-foreground truncate border-r border-border/40">
|
||||
{doc.created_by_name || "—"}
|
||||
{doc.created_by_name ? (
|
||||
doc.created_by_email ? (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="cursor-default truncate block">
|
||||
{doc.created_by_name}
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" align="start">
|
||||
{doc.created_by_email}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
) : (
|
||||
<span className="truncate block">{doc.created_by_name}</span>
|
||||
)
|
||||
) : (
|
||||
<span className="truncate block">{doc.created_by_email || "—"}</span>
|
||||
)}
|
||||
</TableCell>
|
||||
)}
|
||||
{columnVisibility.created_at && (
|
||||
|
|
@ -765,11 +782,11 @@ export function DocumentsTableShell({
|
|||
|
||||
{/* Document Content Viewer - lazy loads content on-demand */}
|
||||
<Dialog open={!!viewingDoc} onOpenChange={(open) => !open && handleCloseViewer()}>
|
||||
<DialogContent className="max-w-4xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogHeader>
|
||||
<DialogContent className="max-w-4xl max-h-[80vh] flex flex-col overflow-hidden pb-0">
|
||||
<DialogHeader className="flex-shrink-0">
|
||||
<DialogTitle>{viewingDoc?.title}</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="mt-4">
|
||||
<div className="mt-4 overflow-y-auto flex-1 min-h-0 px-6 select-text">
|
||||
{viewingLoading ? (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner size="lg" className="text-muted-foreground" />
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ export function PaginationControls({
|
|||
|
||||
return (
|
||||
<motion.div
|
||||
className="flex items-center justify-end gap-3 py-3 px-2"
|
||||
className="flex items-center justify-end gap-3 py-3 px-2 select-none"
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ type: "spring", stiffness: 300, damping: 30, delay: 0.3 }}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ export type Document = {
|
|||
search_space_id: number;
|
||||
created_by_id?: string | null;
|
||||
created_by_name?: string | null;
|
||||
created_by_email?: string | null;
|
||||
status?: DocumentStatus;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -115,6 +115,7 @@ export default function DocumentsTable() {
|
|||
title: item.title,
|
||||
created_by_id: item.created_by_id ?? null,
|
||||
created_by_name: item.created_by_name ?? null,
|
||||
created_by_email: item.created_by_email ?? null,
|
||||
created_at: item.created_at,
|
||||
status: (
|
||||
item as {
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ import {
|
|||
ChevronRight,
|
||||
ChevronUp,
|
||||
CircleAlert,
|
||||
CircleX,
|
||||
Clock,
|
||||
Columns3,
|
||||
Filter,
|
||||
|
|
@ -741,7 +740,7 @@ function LogsFilters({
|
|||
inputRef.current?.focus();
|
||||
}}
|
||||
>
|
||||
<CircleX size={16} strokeWidth={2} />
|
||||
<X size={16} strokeWidth={2} />
|
||||
</Button>
|
||||
)}
|
||||
</motion.div>
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ export default function MorePagesPage() {
|
|||
const allCompleted = data?.tasks.every((t) => t.completed) ?? false;
|
||||
|
||||
return (
|
||||
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
|
||||
<div className="flex min-h-[calc(100vh-64px)] select-none items-center justify-center px-4 py-8">
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
|
|
@ -174,7 +174,7 @@ export default function MorePagesPage() {
|
|||
Contact Us
|
||||
</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent className="sm:max-w-md">
|
||||
<DialogContent className="select-none sm:max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Contact Us</DialogTitle>
|
||||
<DialogDescription>Schedule a meeting or send us an email.</DialogDescription>
|
||||
|
|
|
|||
|
|
@ -38,6 +38,10 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
|
|||
import { DisplayImageToolUI } from "@/components/tool-ui/display-image";
|
||||
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
||||
import { GenerateReportToolUI } from "@/components/tool-ui/generate-report";
|
||||
import {
|
||||
CreateGoogleDriveFileToolUI,
|
||||
DeleteGoogleDriveFileToolUI,
|
||||
} from "@/components/tool-ui/google-drive";
|
||||
import {
|
||||
CreateLinearIssueToolUI,
|
||||
DeleteLinearIssueToolUI,
|
||||
|
|
@ -49,6 +53,7 @@ import {
|
|||
DeleteNotionPageToolUI,
|
||||
UpdateNotionPageToolUI,
|
||||
} from "@/components/tool-ui/notion";
|
||||
import { SandboxExecuteToolUI } from "@/components/tool-ui/sandbox-execute";
|
||||
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
|
||||
import { RecallMemoryToolUI, SaveMemoryToolUI } from "@/components/tool-ui/user-memory";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
|
@ -151,6 +156,9 @@ const TOOLS_WITH_UI = new Set([
|
|||
"create_linear_issue",
|
||||
"update_linear_issue",
|
||||
"delete_linear_issue",
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
"execute",
|
||||
// "write_todos", // Disabled for now
|
||||
]);
|
||||
|
||||
|
|
@ -1664,6 +1672,9 @@ export default function NewChatPage() {
|
|||
<CreateLinearIssueToolUI />
|
||||
<UpdateLinearIssueToolUI />
|
||||
<DeleteLinearIssueToolUI />
|
||||
<CreateGoogleDriveFileToolUI />
|
||||
<DeleteGoogleDriveFileToolUI />
|
||||
<SandboxExecuteToolUI />
|
||||
{/* <WriteTodosToolUI /> Disabled for now */}
|
||||
<div className="flex h-[calc(100dvh-64px)] overflow-hidden">
|
||||
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
||||
|
|
|
|||
|
|
@@ -259,7 +259,7 @@ export default function OnboardPage() {
					You can add more configurations and customize settings anytime in{" "}
					<button
						type="button"
-						onClick={() => router.push(`/dashboard/${searchSpaceId}/settings`)}
+						onClick={() => router.push(`/dashboard/${searchSpaceId}/settings?section=general`)}
						className="text-violet-500 hover:underline"
					>
						Settings
|
||||
|
|
|
|||
|
|
@ -12,10 +12,11 @@ import {
|
|||
Menu,
|
||||
MessageSquare,
|
||||
Settings,
|
||||
Shield,
|
||||
X,
|
||||
} from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useParams, useRouter, useSearchParams } from "next/navigation";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { PublicChatSnapshotsManager } from "@/components/public-chat-snapshots/public-chat-snapshots-manager";
|
||||
|
|
@ -24,6 +25,7 @@ import { ImageModelManager } from "@/components/settings/image-model-manager";
|
|||
import { LLMRoleManager } from "@/components/settings/llm-role-manager";
|
||||
import { ModelConfigManager } from "@/components/settings/model-config-manager";
|
||||
import { PromptConfigManager } from "@/components/settings/prompt-config-manager";
|
||||
import { RolesManager } from "@/components/settings/roles-manager";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { trackSettingsViewed } from "@/lib/posthog/events";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -72,6 +74,12 @@ const settingsNavItems: SettingsNavItem[] = [
|
|||
descriptionKey: "nav_public_links_desc",
|
||||
icon: Globe,
|
||||
},
|
||||
{
|
||||
id: "team-roles",
|
||||
labelKey: "nav_team_roles",
|
||||
descriptionKey: "nav_team_roles_desc",
|
||||
icon: Shield,
|
||||
},
|
||||
];
|
||||
|
||||
function SettingsSidebar({
|
||||
|
|
@ -240,7 +248,7 @@ function SettingsContent({
|
|||
{/* Section Header */}
|
||||
<AnimatePresence mode="wait">
|
||||
<motion.div
|
||||
key={activeSection + "-header"}
|
||||
key={`${activeSection}-header`}
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -10 }}
|
||||
|
|
@ -298,6 +306,7 @@ function SettingsContent({
|
|||
{activeSection === "public-links" && (
|
||||
<PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />
|
||||
)}
|
||||
{activeSection === "team-roles" && <RolesManager searchSpaceId={searchSpaceId} />}
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
|
|
@ -306,14 +315,27 @@ function SettingsContent({
|
|||
);
|
||||
}
|
||||
|
||||
const VALID_SECTIONS = new Set(settingsNavItems.map((item) => item.id));
|
||||
const DEFAULT_SECTION = "general";
|
||||
|
||||
export default function SettingsPage() {
|
||||
const router = useRouter();
|
||||
const params = useParams();
|
||||
const searchParams = useSearchParams();
|
||||
const searchSpaceId = Number(params.search_space_id);
|
||||
const [activeSection, setActiveSection] = useState("general");
|
||||
const [isSidebarOpen, setIsSidebarOpen] = useState(false);
|
||||
|
||||
// Track settings section view
|
||||
const sectionParam = searchParams.get("section");
|
||||
const activeSection =
|
||||
sectionParam && VALID_SECTIONS.has(sectionParam) ? sectionParam : DEFAULT_SECTION;
|
||||
|
||||
const handleSectionChange = useCallback(
|
||||
(section: string) => {
|
||||
router.replace(`/dashboard/${searchSpaceId}/settings?section=${section}`, { scroll: false });
|
||||
},
|
||||
[router, searchSpaceId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
trackSettingsViewed(searchSpaceId, activeSection);
|
||||
}, [searchSpaceId, activeSection]);
|
||||
|
|
@ -333,7 +355,7 @@ export default function SettingsPage() {
|
|||
<div className="flex h-full w-full overflow-hidden bg-background md:rounded-xl md:border md:shadow-sm">
|
||||
<SettingsSidebar
|
||||
activeSection={activeSection}
|
||||
onSectionChange={setActiveSection}
|
||||
onSectionChange={handleSectionChange}
|
||||
onBackToApp={handleBackToApp}
|
||||
isOpen={isSidebarOpen}
|
||||
onClose={() => setIsSidebarOpen(false)}
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -221,7 +221,7 @@ export const ConnectorIndicator: FC<{ hideTrigger?: boolean }> = ({ hideTrigger
|
|||
</TooltipIconButton>
|
||||
)}
|
||||
|
||||
<DialogContent className="max-w-3xl w-[95vw] sm:w-full h-[75vh] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border bg-muted text-foreground focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0 [&>button]:right-4 sm:[&>button]:right-12 [&>button]:top-6 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button_svg]:size-5">
|
||||
<DialogContent className="max-w-3xl w-[95vw] sm:w-full h-[75vh] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border bg-muted text-foreground focus:outline-none focus:ring-0 focus-visible:outline-none focus-visible:ring-0 [&>button]:right-4 sm:[&>button]:right-12 [&>button]:top-6 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button_svg]:size-5 select-none">
|
||||
<DialogTitle className="sr-only">Manage Connectors</DialogTitle>
|
||||
{/* YouTube Crawler View - shown when adding YouTube videos */}
|
||||
{isYouTubeView && searchSpaceId ? (
|
||||
|
|
@ -374,7 +374,7 @@ export const ConnectorIndicator: FC<{ hideTrigger?: boolean }> = ({ hideTrigger
|
|||
: "You need to configure a Document Summary LLM before adding connectors. This LLM is used to process and summarize documents from your connected sources."}
|
||||
</p>
|
||||
<Button asChild size="sm" variant="outline">
|
||||
<Link href={`/dashboard/${searchSpaceId}/settings`}>
|
||||
<Link href={`/dashboard/${searchSpaceId}/settings?section=models`}>
|
||||
<Settings className="mr-2 h-4 w-4" />
|
||||
Go to Settings
|
||||
</Link>
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ const DocumentUploadPopupContent: FC<{
|
|||
|
||||
return (
|
||||
<Dialog open={isOpen} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-4xl w-[95vw] sm:w-full h-[calc(100dvh-2rem)] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-12 [&>button]:top-3 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button]:z-[100] [&>button_svg]:size-4 sm:[&>button_svg]:size-5">
|
||||
<DialogContent className="select-none max-w-4xl w-[95vw] sm:w-full h-[calc(100dvh-2rem)] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-12 [&>button]:top-3 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button]:z-[100] [&>button_svg]:size-4 sm:[&>button_svg]:size-5">
|
||||
<DialogTitle className="sr-only">Upload Document</DialogTitle>
|
||||
|
||||
{/* Scrollable container for mobile */}
|
||||
|
|
@ -129,9 +129,6 @@ const DocumentUploadPopupContent: FC<{
|
|||
<div className="sticky top-0 z-20 bg-muted px-4 sm:px-12 pt-4 sm:pt-10 pb-2 sm:pb-0">
|
||||
{/* Upload header */}
|
||||
<div className="flex items-center gap-2 sm:gap-4 mb-2 sm:mb-6">
|
||||
<div className="flex h-9 w-9 sm:h-14 sm:w-14 items-center justify-center rounded-lg sm:rounded-xl bg-primary/10 border border-primary/20 flex-shrink-0">
|
||||
<Upload className="size-4 sm:size-7 text-primary" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0 pr-8 sm:pr-0">
|
||||
<h2 className="text-base sm:text-2xl font-semibold tracking-tight">
|
||||
Upload Documents
|
||||
|
|
@ -156,7 +153,7 @@ const DocumentUploadPopupContent: FC<{
|
|||
: "You need to configure a Document Summary LLM before uploading files. This LLM is used to process and summarize your uploaded documents."}
|
||||
</p>
|
||||
<Button asChild size="sm" variant="outline">
|
||||
<Link href={`/dashboard/${searchSpaceId}/settings`}>
|
||||
<Link href={`/dashboard/${searchSpaceId}/settings?section=models`}>
|
||||
<Settings className="mr-2 h-4 w-4" />
|
||||
Go to Settings
|
||||
</Link>
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import {
|
|||
ChevronLeftIcon,
|
||||
ChevronRightIcon,
|
||||
CopyIcon,
|
||||
Dot,
|
||||
DownloadIcon,
|
||||
FileWarning,
|
||||
Paperclip,
|
||||
|
|
@ -81,6 +82,10 @@ const CYCLING_PLACEHOLDERS = [
|
|||
const CHAT_UPLOAD_ACCEPT =
|
||||
".pdf,.doc,.docx,.txt,.md,.markdown,.ppt,.pptx,.xls,.xlsx,.xlsm,.xlsb,.csv,.html,.htm,.xml,.rtf,.epub,.jpg,.jpeg,.png,.bmp,.webp,.tiff,.tif,.mp3,.mp4,.mpeg,.mpga,.m4a,.wav,.webm";
|
||||
|
||||
const CHAT_MAX_FILES = 10;
|
||||
const CHAT_MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024; // 50 MB per file
|
||||
const CHAT_MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024; // 200 MB total
|
||||
|
||||
type UploadState = "pending" | "processing" | "ready" | "failed";
|
||||
|
||||
interface UploadedMentionDoc {
|
||||
|
|
@ -533,6 +538,28 @@ const Composer: FC = () => {
|
|||
event.target.value = "";
|
||||
if (files.length === 0 || !search_space_id) return;
|
||||
|
||||
if (files.length > CHAT_MAX_FILES) {
|
||||
toast.error(`Too many files. Maximum ${CHAT_MAX_FILES} files per upload.`);
|
||||
return;
|
||||
}
|
||||
|
||||
let totalSize = 0;
|
||||
for (const file of files) {
|
||||
if (file.size > CHAT_MAX_FILE_SIZE_BYTES) {
|
||||
toast.error(
|
||||
`File "${file.name}" (${(file.size / (1024 * 1024)).toFixed(1)} MB) exceeds the ${CHAT_MAX_FILE_SIZE_BYTES / (1024 * 1024)} MB per-file limit.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
totalSize += file.size;
|
||||
}
|
||||
if (totalSize > CHAT_MAX_TOTAL_SIZE_BYTES) {
|
||||
toast.error(
|
||||
`Total upload size (${(totalSize / (1024 * 1024)).toFixed(1)} MB) exceeds the ${CHAT_MAX_TOTAL_SIZE_BYTES / (1024 * 1024)} MB limit.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsUploadingDocs(true);
|
||||
try {
|
||||
const uploadResponse = await documentsApiService.uploadDocument({
|
||||
|
|
@ -745,7 +772,19 @@ const ComposerAction: FC<ComposerActionProps> = ({
|
|||
<div className="aui-composer-action-wrapper relative mx-2 mb-2 flex items-center justify-between">
|
||||
<div className="flex items-center gap-1">
|
||||
<TooltipIconButton
|
||||
tooltip={isUploadingDocs ? "Uploading documents..." : "Upload and mention files"}
|
||||
tooltip={
|
||||
isUploadingDocs ? (
|
||||
"Uploading documents..."
|
||||
) : (
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<span className="font-medium">Upload and mention files</span>
|
||||
<span className="text-xs text-muted-foreground flex items-center">
|
||||
Max 10 files <Dot className="size-3" /> 50 MB each
|
||||
</span>
|
||||
<span className="text-xs text-muted-foreground">Total upload limit: 200 MB</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
side="bottom"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"use client";
|
||||
|
||||
import { Slottable } from "@radix-ui/react-slot";
|
||||
import { type ComponentPropsWithRef, forwardRef } from "react";
|
||||
import { type ComponentPropsWithRef, forwardRef, type ReactNode } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export type TooltipIconButtonProps = ComponentPropsWithRef<typeof Button> & {
|
||||
tooltip: string;
|
||||
tooltip: ReactNode;
|
||||
side?: "top" | "bottom" | "left" | "right";
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ export function DashboardBreadcrumb() {
|
|||
}
|
||||
|
||||
return (
|
||||
<Breadcrumb>
|
||||
<Breadcrumb className="select-none">
|
||||
<BreadcrumbList>
|
||||
{breadcrumbs.map((item, index) => (
|
||||
<React.Fragment key={`${index}-${item.href || item.label}`}>
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import Link from "next/link";
|
|||
import type React from "react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import Balancer from "react-wrap-balancer";
|
||||
import { WalkthroughScroll } from "@/components/ui/walkthrough-scroll";
|
||||
import { HeroCarousel } from "@/components/ui/hero-carousel";
|
||||
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
|
||||
import { trackLoginAttempt } from "@/lib/posthog/events";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -97,18 +97,18 @@ export function HeroSection() {
|
|||
)}
|
||||
</h2>
|
||||
{/* // TODO:aCTUAL DESCRITION */}
|
||||
<p className="relative z-50 mx-auto mt-4 max-w-lg px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
|
||||
<p className="relative z-50 mx-auto mt-4 max-w-xl px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
|
||||
Connect any AI to your documents, Drive, Notion and more,
|
||||
</p>
|
||||
<p className="relative z-50 mx-auto mt-0 max-w-lg px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
|
||||
then chat with it, invite your team, or generate podcasts and reports.
|
||||
<p className="relative z-50 mx-auto mt-0 max-w-xl px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
|
||||
then chat with it, generate podcasts and reports, or even invite your team.
|
||||
</p>
|
||||
<div className="mb-6 mt-6 flex w-full flex-col items-center justify-center gap-4 px-8 sm:flex-row md:mb-10">
|
||||
<GetStartedButton />
|
||||
{/* <ContactSalesButton /> */}
|
||||
</div>
|
||||
<div ref={containerRef} className="relative w-full">
|
||||
<WalkthroughScroll />
|
||||
<div ref={containerRef} className="relative w-full z-51">
|
||||
<HeroCarousel />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@@ -1 +1 @@
-<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path fill="currentColor" d="M4.94 4.96a9.97 9.97 0 0 1 10.835-2.182a8.7 8.7 0 0 1 2.033 1.11l-3.006 1.39C12.003 4.101 8.797 4.9 6.84 6.86c-2.564 2.565-3.146 6.954-.36 9.922l.278.284L.124 23c1.875-1.973 3.771-4.427 2.636-7.19c-1.52-3.698-.635-8.03 2.18-10.85M23.9.1c-2.264 3.174-3.184 5.389-2.197 9.64l-.007-.007c.753 3.201-.052 6.75-2.653 9.355c-3.279 3.285-8.526 4.016-12.847 1.06L9.21 18.75c2.758 1.084 5.775.607 7.943-1.564c2.169-2.17 2.655-5.332 1.566-7.963c-.207-.5-.828-.625-1.263-.304L8.59 15.472l12.7-12.77v.01z"/></svg>
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path fill="currentColor" d="M6.469 8.776L16.512 23h-4.464L2.005 8.776H6.47zm-.004 7.9l2.233 3.164L6.467 23H2l4.465-6.324zM22 2.582V23h-3.659V7.764L22 2.582zM22 1l-9.952 14.095-2.233-3.163L17.533 1H22z"/></svg>
Size before: 589 B | after: 270 B
|
|
@ -334,7 +334,7 @@ export function LayoutDataProvider({
|
|||
|
||||
const handleSearchSpaceSettings = useCallback(
|
||||
(space: SearchSpace) => {
|
||||
router.push(`/dashboard/${space.id}/settings`);
|
||||
router.push(`/dashboard/${space.id}/settings?section=general`);
|
||||
},
|
||||
[router]
|
||||
);
|
||||
|
|
@ -478,7 +478,7 @@ export function LayoutDataProvider({
|
|||
);
|
||||
|
||||
const handleSettings = useCallback(() => {
|
||||
router.push(`/dashboard/${searchSpaceId}/settings`);
|
||||
router.push(`/dashboard/${searchSpaceId}/settings?section=general`);
|
||||
}, [router, searchSpaceId]);
|
||||
|
||||
const handleManageMembers = useCallback(() => {
|
||||
|
|
@ -703,7 +703,6 @@ export function LayoutDataProvider({
|
|||
<DialogContent className="sm:max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<PencilIcon className="h-5 w-5" />
|
||||
<span>{tSidebar("rename_chat") || "Rename Chat"}</span>
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
|
|
@ -736,7 +735,7 @@ export function LayoutDataProvider({
|
|||
{isRenamingChat ? (
|
||||
<>
|
||||
<span className="h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent" />
|
||||
{tSidebar("renaming") || "Renaming..."}
|
||||
{tSidebar("renaming") || "Renaming"}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { Plus, Search } from "lucide-react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
|
|
@ -86,9 +85,6 @@ export function CreateSearchSpaceDialog({ open, onOpenChange }: CreateSearchSpac
|
|||
<DialogContent className="max-w-[90vw] sm:max-w-sm p-4 sm:p-5 data-[state=open]:animate-none data-[state=closed]:animate-none">
|
||||
<DialogHeader className="space-y-2 pb-2">
|
||||
<div className="flex items-center gap-2 sm:gap-3">
|
||||
<div className="flex h-8 w-8 sm:h-10 sm:w-10 items-center justify-center rounded-lg bg-primary/10 flex-shrink-0">
|
||||
<Search className="h-4 w-4 sm:h-5 sm:w-5 text-primary" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<DialogTitle className="text-base sm:text-lg">{t("create_title")}</DialogTitle>
|
||||
<DialogDescription className="text-xs sm:text-sm mt-0.5">
|
||||
|
|
@ -142,20 +138,20 @@ export function CreateSearchSpaceDialog({ open, onOpenChange }: CreateSearchSpac
|
|||
)}
|
||||
/>
|
||||
|
||||
<DialogFooter className="flex-col sm:flex-row gap-2 pt-2 sm:pt-3">
|
||||
<DialogFooter className="flex-row gap-2 pt-2 sm:pt-3">
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => handleOpenChange(false)}
|
||||
disabled={isSubmitting}
|
||||
className="w-full sm:w-auto h-9 sm:h-10 text-sm"
|
||||
className="flex-1 sm:flex-none sm:w-auto h-8 sm:h-10 text-xs sm:text-sm"
|
||||
>
|
||||
{tCommon("cancel")}
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="w-full sm:w-auto h-9 sm:h-10 text-sm"
|
||||
className="flex-1 sm:flex-none sm:w-auto h-8 sm:h-10 text-xs sm:text-sm"
|
||||
>
|
||||
{isSubmitting ? (
|
||||
<>
|
||||
|
|
@ -163,10 +159,7 @@ export function CreateSearchSpaceDialog({ open, onOpenChange }: CreateSearchSpac
|
|||
{t("creating")}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Plus className="-mr-1 h-4 w-4" />
|
||||
{t("create_button")}
|
||||
</>
|
||||
<>{t("create_button")}</>
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
|
|
|
|||
|
|
@ -210,6 +210,26 @@ export function LayoutShell({
|
|||
onCloseMobileSidebar={() => setMobileMenuOpen(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Mobile All Shared Chats - slide-out panel */}
|
||||
{allSharedChatsPanel && (
|
||||
<AllSharedChatsSidebar
|
||||
open={allSharedChatsPanel.open}
|
||||
onOpenChange={allSharedChatsPanel.onOpenChange}
|
||||
searchSpaceId={allSharedChatsPanel.searchSpaceId}
|
||||
onCloseMobileSidebar={() => setMobileMenuOpen(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Mobile All Private Chats - slide-out panel */}
|
||||
{allPrivateChatsPanel && (
|
||||
<AllPrivateChatsSidebar
|
||||
open={allPrivateChatsPanel.open}
|
||||
onOpenChange={allPrivateChatsPanel.onOpenChange}
|
||||
searchSpaceId={allPrivateChatsPanel.searchSpaceId}
|
||||
onCloseMobileSidebar={() => setMobileMenuOpen(false)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</TooltipProvider>
|
||||
</SidebarProvider>
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@ import { useQuery, useQueryClient } from "@tanstack/react-query";
|
|||
import { format } from "date-fns";
|
||||
import {
|
||||
ArchiveIcon,
|
||||
ChevronLeft,
|
||||
MessageCircleMore,
|
||||
MoreHorizontal,
|
||||
PenLine,
|
||||
RotateCcwIcon,
|
||||
Search,
|
||||
Trash2,
|
||||
|
|
@ -17,6 +19,14 @@ import { useTranslations } from "next-intl";
|
|||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
|
|
@ -69,6 +79,10 @@ export function AllPrivateChatsSidebar({
|
|||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const [showArchived, setShowArchived] = useState(false);
|
||||
const [openDropdownId, setOpenDropdownId] = useState<number | null>(null);
|
||||
const [showRenameDialog, setShowRenameDialog] = useState(false);
|
||||
const [renamingThread, setRenamingThread] = useState<{ id: number; title: string } | null>(null);
|
||||
const [newTitle, setNewTitle] = useState("");
|
||||
const [isRenaming, setIsRenaming] = useState(false);
|
||||
const debouncedSearchQuery = useDebouncedValue(searchQuery, 300);
|
||||
|
||||
const isSearchMode = !!debouncedSearchQuery.trim();
|
||||
|
|
@ -187,6 +201,35 @@ export function AllPrivateChatsSidebar({
|
|||
[queryClient, searchSpaceId, t]
|
||||
);
|
||||
|
||||
const handleStartRename = useCallback((threadId: number, title: string) => {
|
||||
setRenamingThread({ id: threadId, title });
|
||||
setNewTitle(title);
|
||||
setShowRenameDialog(true);
|
||||
}, []);
|
||||
|
||||
const handleConfirmRename = useCallback(async () => {
|
||||
if (!renamingThread || !newTitle.trim()) return;
|
||||
setIsRenaming(true);
|
||||
try {
|
||||
await updateThread(renamingThread.id, { title: newTitle.trim() });
|
||||
toast.success(t("chat_renamed") || "Chat renamed");
|
||||
queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
|
||||
queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
|
||||
queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["threads", searchSpaceId, "detail", String(renamingThread.id)],
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error renaming thread:", error);
|
||||
toast.error(t("error_renaming_chat") || "Failed to rename chat");
|
||||
} finally {
|
||||
setIsRenaming(false);
|
||||
setShowRenameDialog(false);
|
||||
setRenamingThread(null);
|
||||
setNewTitle("");
|
||||
}
|
||||
}, [renamingThread, newTitle, queryClient, searchSpaceId, t]);
|
||||
|
||||
const handleClearSearch = useCallback(() => {
|
||||
setSearchQuery("");
|
||||
}, []);
|
||||
|
|
@ -205,6 +248,17 @@ export function AllPrivateChatsSidebar({
|
|||
>
|
||||
<div className="shrink-0 p-4 pb-2 space-y-3">
|
||||
<div className="flex items-center gap-2">
|
||||
{isMobile && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-8 w-8 rounded-full"
|
||||
onClick={() => onOpenChange(false)}
|
||||
>
|
||||
<ChevronLeft className="h-4 w-4 text-muted-foreground" />
|
||||
<span className="sr-only">{t("close") || "Close"}</span>
|
||||
</Button>
|
||||
)}
|
||||
<User className="h-5 w-5 text-primary" />
|
||||
<h2 className="text-lg font-semibold">{t("chats") || "Private Chats"}</h2>
|
||||
</div>
|
||||
|
|
@ -356,6 +410,14 @@ export function AllPrivateChatsSidebar({
|
|||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" className="w-40 z-80">
|
||||
{!thread.archived && (
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleStartRename(thread.id, thread.title || "New Chat")}
|
||||
>
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
<span>{t("rename") || "Rename"}</span>
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleToggleArchive(thread.id, thread.archived)}
|
||||
disabled={isArchiving}
|
||||
|
|
@ -412,6 +474,51 @@ export function AllPrivateChatsSidebar({
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Dialog open={showRenameDialog} onOpenChange={setShowRenameDialog}>
|
||||
<DialogContent className="sm:max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<span>{t("rename_chat") || "Rename Chat"}</span>
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t("rename_chat_description") || "Enter a new name for this conversation."}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<Input
|
||||
value={newTitle}
|
||||
onChange={(e) => setNewTitle(e.target.value)}
|
||||
placeholder={t("chat_title_placeholder") || "Chat title"}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !isRenaming && newTitle.trim()) {
|
||||
handleConfirmRename();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<DialogFooter className="flex gap-2 sm:justify-end">
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setShowRenameDialog(false)}
|
||||
disabled={isRenaming}
|
||||
>
|
||||
{t("cancel")}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleConfirmRename}
|
||||
disabled={isRenaming || !newTitle.trim()}
|
||||
className="gap-2"
|
||||
>
|
||||
{isRenaming ? (
|
||||
<>
|
||||
<Spinner size="xs" />
|
||||
<span>{t("renaming") || "Renaming"}</span>
|
||||
</>
|
||||
) : (
|
||||
<span>{t("rename") || "Rename"}</span>
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</SidebarSlideOutPanel>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@ import { useQuery, useQueryClient } from "@tanstack/react-query";
|
|||
import { format } from "date-fns";
|
||||
import {
|
||||
ArchiveIcon,
|
||||
ChevronLeft,
|
||||
MessageCircleMore,
|
||||
MoreHorizontal,
|
||||
PenLine,
|
||||
RotateCcwIcon,
|
||||
Search,
|
||||
Trash2,
|
||||
|
|
@ -17,6 +19,14 @@ import { useTranslations } from "next-intl";
|
|||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
|
|
@ -69,6 +79,10 @@ export function AllSharedChatsSidebar({
|
|||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const [showArchived, setShowArchived] = useState(false);
|
||||
const [openDropdownId, setOpenDropdownId] = useState<number | null>(null);
|
||||
const [showRenameDialog, setShowRenameDialog] = useState(false);
|
||||
const [renamingThread, setRenamingThread] = useState<{ id: number; title: string } | null>(null);
|
||||
const [newTitle, setNewTitle] = useState("");
|
||||
const [isRenaming, setIsRenaming] = useState(false);
|
||||
const debouncedSearchQuery = useDebouncedValue(searchQuery, 300);
|
||||
|
||||
const isSearchMode = !!debouncedSearchQuery.trim();
|
||||
|
|
@ -187,6 +201,35 @@ export function AllSharedChatsSidebar({
|
|||
[queryClient, searchSpaceId, t]
|
||||
);
|
||||
|
||||
const handleStartRename = useCallback((threadId: number, title: string) => {
|
||||
setRenamingThread({ id: threadId, title });
|
||||
setNewTitle(title);
|
||||
setShowRenameDialog(true);
|
||||
}, []);
|
||||
|
||||
const handleConfirmRename = useCallback(async () => {
|
||||
if (!renamingThread || !newTitle.trim()) return;
|
||||
setIsRenaming(true);
|
||||
try {
|
||||
await updateThread(renamingThread.id, { title: newTitle.trim() });
|
||||
toast.success(t("chat_renamed") || "Chat renamed");
|
||||
queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
|
||||
queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
|
||||
queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["threads", searchSpaceId, "detail", String(renamingThread.id)],
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error renaming thread:", error);
|
||||
toast.error(t("error_renaming_chat") || "Failed to rename chat");
|
||||
} finally {
|
||||
setIsRenaming(false);
|
||||
setShowRenameDialog(false);
|
||||
setRenamingThread(null);
|
||||
setNewTitle("");
|
||||
}
|
||||
}, [renamingThread, newTitle, queryClient, searchSpaceId, t]);
|
||||
|
||||
const handleClearSearch = useCallback(() => {
|
||||
setSearchQuery("");
|
||||
}, []);
|
||||
|
|
@ -205,6 +248,17 @@ export function AllSharedChatsSidebar({
|
|||
>
|
||||
<div className="shrink-0 p-4 pb-2 space-y-3">
|
||||
<div className="flex items-center gap-2">
|
||||
{isMobile && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-8 w-8 rounded-full"
|
||||
onClick={() => onOpenChange(false)}
|
||||
>
|
||||
<ChevronLeft className="h-4 w-4 text-muted-foreground" />
|
||||
<span className="sr-only">{t("close") || "Close"}</span>
|
||||
</Button>
|
||||
)}
|
||||
<Users className="h-5 w-5 text-primary" />
|
||||
<h2 className="text-lg font-semibold">{t("shared_chats") || "Shared Chats"}</h2>
|
||||
</div>
|
||||
|
|
@ -356,6 +410,14 @@ export function AllSharedChatsSidebar({
|
|||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" className="w-40 z-80">
|
||||
{!thread.archived && (
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleStartRename(thread.id, thread.title || "New Chat")}
|
||||
>
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
<span>{t("rename") || "Rename"}</span>
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem
|
||||
onClick={() => handleToggleArchive(thread.id, thread.archived)}
|
||||
disabled={isArchiving}
|
||||
|
|
@ -412,6 +474,51 @@ export function AllSharedChatsSidebar({
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Dialog open={showRenameDialog} onOpenChange={setShowRenameDialog}>
|
||||
<DialogContent className="sm:max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<span>{t("rename_chat") || "Rename Chat"}</span>
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t("rename_chat_description") || "Enter a new name for this conversation."}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<Input
|
||||
value={newTitle}
|
||||
onChange={(e) => setNewTitle(e.target.value)}
|
||||
placeholder={t("chat_title_placeholder") || "Chat title"}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !isRenaming && newTitle.trim()) {
|
||||
handleConfirmRename();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<DialogFooter className="flex gap-2 sm:justify-end">
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setShowRenameDialog(false)}
|
||||
disabled={isRenaming}
|
||||
>
|
||||
{t("cancel")}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleConfirmRename}
|
||||
disabled={isRenaming || !newTitle.trim()}
|
||||
className="gap-2"
|
||||
>
|
||||
{isRenaming ? (
|
||||
<>
|
||||
<Spinner size="xs" />
|
||||
<span>{t("renaming") || "Renaming"}</span>
|
||||
</>
|
||||
) : (
|
||||
<span>{t("rename") || "Rename"}</span>
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</SidebarSlideOutPanel>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,7 @@ import {
	ArchiveIcon,
	MessageSquare,
	MoreHorizontal,
-	PencilIcon,
+	PenLine,
	RotateCcwIcon,
	Trash2,
} from "lucide-react";

@@ -74,7 +74,7 @@ export function ChatListItem({
					onRename();
				}}
			>
-				<PencilIcon className="mr-2 h-4 w-4" />
+				<PenLine className="mr-2 h-4 w-4" />
				<span>{t("rename") || "Rename"}</span>
			</DropdownMenuItem>
		)}
|
||||
|
|
|
|||
|
|
@@ -727,7 +727,7 @@ export function InboxSidebar({
						</Tooltip>
						<DropdownMenuContent
							align="end"
-							className={cn("z-80", activeTab === "status" ? "w-52" : "w-44")}
+							className={cn("z-80 select-none", activeTab === "status" ? "w-52" : "w-44")}
						>
							<DropdownMenuLabel className="text-xs text-muted-foreground/80 font-normal">
								{t("filter") || "Filter"}
|
||||
|
|
|
|||
|
|
@@ -131,7 +131,7 @@ export function MobileSidebar({
			</div>

			{/* Sidebar Content - right side */}
-			<div className="flex-1 overflow-hidden flex flex-col">
+			<div className="flex-1 overflow-hidden flex flex-col [&>*]:!w-full">
				<Sidebar
					searchSpace={searchSpace}
					isCollapsed={false}
|
||||
|
|
|
|||
|
|
@@ -93,7 +93,7 @@ export function Sidebar({
	return (
		<div
			className={cn(
-				"relative flex h-full flex-col bg-sidebar text-sidebar-foreground overflow-hidden",
+				"relative flex h-full flex-col bg-sidebar text-sidebar-foreground overflow-hidden select-none",
				isCollapsed ? "w-[60px] transition-all duration-200" : "",
				!isCollapsed && !isResizing ? "transition-all duration-200" : "",
				className
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,6 @@
"use client";

-import { ChevronsUpDown, Settings, Users } from "lucide-react";
+import { ChevronsUpDown, Settings, UserPen } from "lucide-react";
import { useParams, useRouter } from "next/navigation";
import { useTranslations } from "next-intl";
import { Button } from "@/components/ui/button";

@@ -51,14 +51,14 @@ export function SidebarHeader({
					<ChevronsUpDown className="h-4 w-4 shrink-0 text-muted-foreground" />
				</Button>
			</DropdownMenuTrigger>
-			<DropdownMenuContent align="start" className="w-56">
+			<DropdownMenuContent align="start" className="w-48">
				<DropdownMenuItem onClick={onManageMembers}>
-					<Users className="mr-2 h-4 w-4" />
+					<UserPen className="h-4 w-4" />
					{t("manage_members")}
				</DropdownMenuItem>
				<DropdownMenuSeparator />
				<DropdownMenuItem onClick={onSettings}>
-					<Settings className="mr-2 h-4 w-4" />
+					<Settings className="h-4 w-4" />
					{t("search_space_settings")}
				</DropdownMenuItem>
			</DropdownMenuContent>
|
||||
|
|
|
|||
|
|
@@ -65,7 +65,7 @@ export function SidebarSlideOutPanel({
			exit={{ x: "-100%" }}
			transition={{ type: "tween", duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
			className={cn(
-				"h-full w-full bg-background flex flex-col pointer-events-auto",
+				"h-full w-full bg-background flex flex-col pointer-events-auto select-none",
				"sm:border-r sm:shadow-xl"
			)}
			role="dialog"
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff.