mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
Merge pull request #840 from MODSetter/dev
feat: Google Drive HITL tools, team management UI, Daytona sandboxes, indexing pipeline hardening & test infrastructure
commit
3ca401cb2c
184 changed files with 17831 additions and 9340 deletions
112
.cursor/skills/tdd/SKILL.md
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
---
|
||||
name: tdd
|
||||
description: Strict Python TDD workflow using pytest (Red-Green-Refactor).
|
||||
---
|
||||
|
||||
---
|
||||
name: tdd
|
||||
description: Test-driven development with red-green-refactor loop. Use when user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development.
|
||||
---
|
||||
|
||||
# Test-Driven Development
|
||||
|
||||
## Philosophy
|
||||
|
||||
**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't.
|
||||
|
||||
**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure.
|
||||
|
||||
**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior.
|
||||
|
||||
See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines.
|
||||
|
||||
## Anti-Pattern: Horizontal Slices
|
||||
|
||||
**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code."
|
||||
|
||||
This produces **crap tests**:
|
||||
|
||||
- Tests written in bulk test _imagined_ behavior, not _actual_ behavior
|
||||
- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior
|
||||
- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine
|
||||
- You outrun your headlights, committing to test structure before understanding the implementation
|
||||
|
||||
**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it.
|
||||
|
||||
```
|
||||
WRONG (horizontal):
|
||||
RED: test1, test2, test3, test4, test5
|
||||
GREEN: impl1, impl2, impl3, impl4, impl5
|
||||
|
||||
RIGHT (vertical):
|
||||
RED→GREEN: test1→impl1
|
||||
RED→GREEN: test2→impl2
|
||||
RED→GREEN: test3→impl3
|
||||
...
|
||||
```
|
||||
|
||||
## Workflow
|
||||
|
||||
### 1. Planning
|
||||
|
||||
Before writing any code:
|
||||
|
||||
- [ ] Confirm with user what interface changes are needed
|
||||
- [ ] Confirm with user which behaviors to test (prioritize)
|
||||
- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation)
|
||||
- [ ] Design interfaces for [testability](interface-design.md)
|
||||
- [ ] List the behaviors to test (not implementation steps)
|
||||
- [ ] Get user approval on the plan
|
||||
|
||||
Ask: "What should the public interface look like? Which behaviors are most important to test?"
|
||||
|
||||
**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case.
|
||||
|
||||
### 2. Tracer Bullet
|
||||
|
||||
Write ONE test that confirms ONE thing about the system:
|
||||
|
||||
```
|
||||
RED: Write test for first behavior → test fails
|
||||
GREEN: Write minimal code to pass → test passes
|
||||
```
|
||||
|
||||
This is your tracer bullet: it proves the path works end-to-end.
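
A minimal sketch of one RED→GREEN cycle, assuming a hypothetical `Cart` class and `checkout` function (neither exists in this repository). The test is written first against the public interface; the implementation is deliberately naive and only grows when a later test forces it to.

```python
from dataclasses import dataclass, field


# GREEN: written only after the test at the bottom had failed.
# `Cart`, `CheckoutResult` and `checkout` are hypothetical names for illustration.
@dataclass
class Cart:
    items: list[str] = field(default_factory=list)

    def add(self, item: str) -> None:
        self.items.append(item)


@dataclass
class CheckoutResult:
    status: str


def checkout(cart: Cart) -> CheckoutResult:
    # Deliberately minimal: just enough to satisfy the single test below.
    # A later test (for example, an empty cart) would force real logic here.
    return CheckoutResult(status="confirmed")


# RED: this test came first and failed until `checkout` existed.
def test_user_can_checkout_with_valid_cart():
    cart = Cart()
    cart.add("book")
    assert checkout(cart).status == "confirmed"
```

With pytest, running the file collects `test_user_can_checkout_with_valid_cart` automatically; the next behavior gets its own test before any further code is added.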
|
||||
|
||||
### 3. Incremental Loop
|
||||
|
||||
For each remaining behavior:
|
||||
|
||||
```
|
||||
RED: Write next test → fails
|
||||
GREEN: Minimal code to pass → passes
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- One test at a time
|
||||
- Only enough code to pass current test
|
||||
- Don't anticipate future tests
|
||||
- Keep tests focused on observable behavior
|
||||
|
||||
### 4. Refactor
|
||||
|
||||
After all tests pass, look for [refactor candidates](refactoring.md):
|
||||
|
||||
- [ ] Extract duplication
|
||||
- [ ] Deepen modules (move complexity behind simple interfaces)
|
||||
- [ ] Apply SOLID principles where natural
|
||||
- [ ] Consider what new code reveals about existing code
|
||||
- [ ] Run tests after each refactor step
|
||||
|
||||
**Never refactor while RED.** Get to GREEN first.
|
||||
|
||||
## Checklist Per Cycle
|
||||
|
||||
```
|
||||
[ ] Test describes behavior, not implementation
|
||||
[ ] Test uses public interface only
|
||||
[ ] Test would survive internal refactor
|
||||
[ ] Code is minimal for this test
|
||||
[ ] No speculative features added
|
||||
```
|
||||
33
.cursor/skills/tdd/deep-modules.md
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Deep Modules
|
||||
|
||||
From "A Philosophy of Software Design":
|
||||
|
||||
**Deep module** = small interface + lots of implementation
|
||||
|
||||
```
|
||||
┌─────────────────────┐
|
||||
│ Small Interface │ ← Few methods, simple params
|
||||
├─────────────────────┤
|
||||
│ │
|
||||
│ │
|
||||
│ Deep Implementation│ ← Complex logic hidden
|
||||
│ │
|
||||
│ │
|
||||
└─────────────────────┘
|
||||
```
|
||||
|
||||
**Shallow module** = large interface + little implementation (avoid)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────┐
|
||||
│ Large Interface │ ← Many methods, complex params
|
||||
├─────────────────────────────────┤
|
||||
│ Thin Implementation │ ← Just passes through
|
||||
└─────────────────────────────────┘
|
||||
```
|
||||
|
||||
When designing interfaces, ask:
|
||||
|
||||
- Can I reduce the number of methods?
|
||||
- Can I simplify the parameters?
|
||||
- Can I hide more complexity inside?
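
As a rough illustration of the questions above, the sketch below shows what a deep module might look like in Python. `SettingsStore` is a hypothetical class invented for this example: two public methods hide file I/O, JSON parsing, and default handling.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Optional


class SettingsStore:
    """Hypothetical deep module: a small interface over hidden persistence details."""

    def __init__(self, path: Path, defaults: Optional[dict] = None) -> None:
        self._path = path
        self._defaults = dict(defaults or {})

    def get(self, key: str) -> Any:
        return self._load().get(key)

    def set(self, key: str, value: Any) -> None:
        data = self._load()
        data[key] = value
        self._path.write_text(json.dumps(data))

    def _load(self) -> dict:
        # A missing or corrupt file falls back to defaults; callers never see this.
        try:
            stored = json.loads(self._path.read_text())
        except (FileNotFoundError, json.JSONDecodeError):
            stored = {}
        return {**self._defaults, **stored}


# Usage: callers only ever touch get/set.
with tempfile.TemporaryDirectory() as tmp:
    store = SettingsStore(Path(tmp) / "settings.json", defaults={"theme": "light"})
    store.set("theme", "dark")
    print(store.get("theme"))  # dark
```

Tests for such a module can go entirely through `get` and `set`, so the storage format can change without touching them.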
|
||||
33
.cursor/skills/tdd/interface-design.md
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Interface Design for Testability
|
||||
|
||||
Good interfaces make testing natural:
|
||||
|
||||
1. **Accept dependencies, don't create them**
|
||||
```python
# Testable
def process_order(order, payment_gateway):
    pass

# Hard to test
def process_order(order):
    gateway = StripeGateway()

```
|
||||
|
||||
|
||||
2. **Return results, don't produce side effects**
|
||||
```python
# Testable
def calculate_discount(cart) -> float:
    return discount

# Hard to test
def apply_discount(cart) -> None:
    cart.total -= discount

```
|
||||
|
||||
|
||||
3. **Small surface area**
|
||||
* Fewer methods = fewer tests needed
|
||||
* Fewer params = simpler test setup
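
The three points above can be sketched together. Everything here (`Order`, `FakeGateway`, and the function bodies) is hypothetical and only illustrates the shape: the gateway is passed in, and the discount is returned rather than applied as a side effect.

```python
from dataclasses import dataclass


@dataclass
class Order:
    total: float


class FakeGateway:
    """Hypothetical in-memory stand-in for an external payment gateway."""

    def __init__(self) -> None:
        self.charged: list[float] = []

    def charge(self, amount: float) -> str:
        self.charged.append(amount)
        return "ok"


def process_order(order: Order, payment_gateway) -> str:
    # The dependency is accepted, not created, so a test can supply a fake.
    return payment_gateway.charge(order.total)


def calculate_discount(subtotal: float, rate: float) -> float:
    # Returns a result instead of mutating a cart, so it can be asserted directly.
    return round(subtotal * rate, 2)


def test_order_is_charged_through_injected_gateway():
    assert process_order(Order(total=10.0), FakeGateway()) == "ok"


def test_discount_is_returned():
    assert calculate_discount(200.0, 0.1) == 20.0
```

Neither test needs patching or setup beyond constructing plain objects, which is the practical payoff of a small surface area.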
|
||||
69
.cursor/skills/tdd/mocking.md
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
|
||||
# When to Mock
|
||||
|
||||
Mock at **system boundaries** only:
|
||||
|
||||
* External APIs (payment, email, etc.)
|
||||
* Databases (sometimes - prefer test DB)
|
||||
* Time/randomness
|
||||
* File system (sometimes)
|
||||
|
||||
Don't mock:
|
||||
|
||||
* Your own classes/modules
|
||||
* Internal collaborators
|
||||
* Anything you control
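
Time is one boundary from the list above that is easy to control without a mocking library. The sketch below assumes a hypothetical `is_expired` helper that takes the clock as a parameter, so a test can pass a fixed time.

```python
from datetime import datetime, timedelta, timezone
from typing import Callable


def is_expired(
    issued_at: datetime,
    ttl: timedelta,
    now: Callable[[], datetime] = lambda: datetime.now(timezone.utc),
) -> bool:
    # `now` is injectable: production code uses the real clock by default,
    # tests pass a fixed one. The function name and signature are hypothetical.
    return now() - issued_at >= ttl


def test_token_expires_after_ttl():
    fixed = datetime(2026, 1, 1, 12, 0, tzinfo=timezone.utc)
    two_hours_ago = fixed - timedelta(hours=2)
    assert is_expired(two_hours_ago, timedelta(hours=1), now=lambda: fixed)


def test_token_is_valid_within_ttl():
    fixed = datetime(2026, 1, 1, 12, 0, tzinfo=timezone.utc)
    half_hour_ago = fixed - timedelta(minutes=30)
    assert not is_expired(half_hour_ago, timedelta(hours=1), now=lambda: fixed)
```

The same pattern applies to randomness: accept a seeded generator or a callable rather than calling the global source inside the function.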
|
||||
|
||||
## Designing for Mockability
|
||||
|
||||
At system boundaries, design interfaces that are easy to mock:
|
||||
|
||||
**1. Use dependency injection**
|
||||
|
||||
Pass external dependencies in rather than creating them internally:
|
||||
|
||||
```python
import os

# Easy to mock
def process_payment(order, payment_client):
    return payment_client.charge(order.total)

# Hard to mock
def process_payment(order):
    client = StripeClient(os.getenv("STRIPE_KEY"))
    return client.charge(order.total)

```
|
||||
|
||||
**2. Prefer SDK-style interfaces over generic fetchers**
|
||||
|
||||
Create specific functions for each external operation instead of one generic function with conditional logic:
|
||||
|
||||
```python
import requests

# Hypothetical base URL for illustration; requests rejects URLs without a scheme.
BASE_URL = "https://api.example.com"

# GOOD: Each function is independently mockable
class UserAPI:
    def get_user(self, user_id):
        return requests.get(f"{BASE_URL}/users/{user_id}")

    def get_orders(self, user_id):
        return requests.get(f"{BASE_URL}/users/{user_id}/orders")

    def create_order(self, data):
        return requests.post(f"{BASE_URL}/orders", json=data)

# BAD: Mocking requires conditional logic inside the mock
class GenericAPI:
    def fetch(self, endpoint, method="GET", data=None):
        return requests.request(method, endpoint, json=data)

```
|
||||
|
||||
The SDK approach means:
|
||||
|
||||
* Each mock returns one specific shape
|
||||
* No conditional logic in test setup
|
||||
* Easier to see which endpoints a test exercises
|
||||
* Type safety per endpoint
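
To make the first two points concrete, here is a sketch using the standard library's `unittest.mock.create_autospec`. The `UserAPI` stub and `order_summary` function are hypothetical; `UserAPI` stands in for a thin wrapper around an external HTTP API, which is the kind of boundary where mocking is appropriate.

```python
from unittest.mock import create_autospec


class UserAPI:
    """Hypothetical SDK-style wrapper around an external service."""

    def get_user(self, user_id: int) -> dict:
        raise NotImplementedError("would call the real API")

    def get_orders(self, user_id: int) -> list:
        raise NotImplementedError("would call the real API")


def order_summary(api: UserAPI, user_id: int) -> str:
    user = api.get_user(user_id)
    orders = api.get_orders(user_id)
    return f"{user['name']} has {len(orders)} order(s)"


# Each mocked method returns exactly one shape; no branching on endpoint or method.
api = create_autospec(UserAPI, instance=True)
api.get_user.return_value = {"name": "Alice"}
api.get_orders.return_value = [{"id": 1}, {"id": 2}]

print(order_summary(api, 7))  # Alice has 2 order(s)
```

Because the mock is built from the class's spec, calling a method that does not exist or passing the wrong arguments fails immediately, which a hand-rolled mock of a generic `fetch` would not catch.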
|
||||
10
.cursor/skills/tdd/refactoring.md
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# Refactor Candidates
|
||||
|
||||
After each TDD cycle, look for:
|
||||
|
||||
- **Duplication** → Extract function/class
|
||||
- **Long methods** → Break into private helpers (keep tests on public interface)
|
||||
- **Shallow modules** → Combine or deepen
|
||||
- **Feature envy** → Move logic to where data lives
|
||||
- **Primitive obsession** → Introduce value objects
|
||||
- **Existing code** the new code reveals as problematic
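
As one example from the list above, "primitive obsession → value object" might look like the following sketch. `EmailAddress` is a hypothetical type, and its validation rule is intentionally simplistic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EmailAddress:
    """Hypothetical value object replacing a bare `str` passed around the codebase."""

    value: str

    def __post_init__(self) -> None:
        # Validation lives in one place instead of being repeated at every call site.
        if "@" not in self.value:
            raise ValueError(f"invalid email: {self.value!r}")

    @property
    def domain(self) -> str:
        return self.value.rsplit("@", 1)[1]


print(EmailAddress("alice@example.com").domain)  # example.com
```

Existing tests that go through the public interface should keep passing during this refactor; only the internal representation changes.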
|
||||
60
.cursor/skills/tdd/tests.md
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# Good and Bad Tests
|
||||
|
||||
## Good Tests
|
||||
|
||||
**Integration-style**: Test through real interfaces, not mocks of internal parts.
|
||||
|
||||
```python
# GOOD: Tests observable behavior
def test_user_can_checkout_with_valid_cart():
    cart = create_cart()
    cart.add(product)
    result = checkout(cart, payment_method)
    assert result.status == "confirmed"

```
|
||||
|
||||
Characteristics:
|
||||
|
||||
* Tests behavior users/callers care about
|
||||
* Uses public API only
|
||||
* Survives internal refactors
|
||||
* Describes WHAT, not HOW
|
||||
* One logical assertion per test
|
||||
|
||||
## Bad Tests
|
||||
|
||||
**Implementation-detail tests**: Coupled to internal structure.
|
||||
|
||||
```python
from unittest.mock import MagicMock

# BAD: Tests implementation details
def test_checkout_calls_payment_service_process():
    mock_payment = MagicMock()
    checkout(cart, mock_payment)
    mock_payment.process.assert_called_with(cart.total)

```
|
||||
|
||||
Red flags:
|
||||
|
||||
* Mocking internal collaborators
|
||||
* Testing private methods
|
||||
* Asserting on call counts/order
|
||||
* Test breaks when refactoring without behavior change
|
||||
* Test name describes HOW not WHAT
|
||||
* Verifying through external means instead of interface
|
||||
|
||||
```python
# BAD: Bypasses interface to verify
def test_create_user_saves_to_database():
    create_user({"name": "Alice"})
    row = db.query("SELECT * FROM users WHERE name = ?", ["Alice"])
    assert row is not None

# GOOD: Verifies through interface
def test_create_user_makes_user_retrievable():
    user = create_user({"name": "Alice"})
    retrieved = get_user(user.id)
    assert retrieved.name == "Alice"

```
|
||||
|
|
@ -260,6 +260,10 @@ ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000
|
|||
ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
|
||||
ENV NEXT_PUBLIC_ETL_SERVICE=DOCLING
|
||||
|
||||
# Daytona Sandbox (cloud code execution — no local server needed)
|
||||
ENV DAYTONA_SANDBOX_ENABLED=FALSE
|
||||
# DAYTONA_API_KEY, DAYTONA_API_URL, DAYTONA_TARGET: set at runtime for production.
|
||||
|
||||
# Electric SQL configuration (ELECTRIC_DATABASE_URL is built dynamically by entrypoint from these values)
|
||||
ENV ELECTRIC_DB_USER=electric
|
||||
ENV ELECTRIC_DB_PASSWORD=electric_password
|
||||
|
|
|
|||
|
|
@ -65,6 +65,11 @@ services:
|
|||
- ELECTRIC_DB_PASSWORD=${ELECTRIC_DB_PASSWORD:-electric_password}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
|
||||
- NEXT_FRONTEND_URL=${NEXT_FRONTEND_URL:-http://localhost:3000}
|
||||
# Daytona Sandbox – uncomment and set credentials to enable cloud code execution
|
||||
# - DAYTONA_SANDBOX_ENABLED=TRUE
|
||||
# - DAYTONA_API_KEY=${DAYTONA_API_KEY:-}
|
||||
# - DAYTONA_API_URL=${DAYTONA_API_URL:-https://app.daytona.io/api}
|
||||
# - DAYTONA_TARGET=${DAYTONA_TARGET:-us}
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
|
|
|
|||
|
|
@ -232,6 +232,7 @@ echo " Auth Type: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE}"
|
|||
echo " ETL Service: ${NEXT_PUBLIC_ETL_SERVICE}"
|
||||
echo " TTS Service: ${TTS_SERVICE}"
|
||||
echo " STT Service: ${STT_SERVICE}"
|
||||
echo " Daytona Sandbox: ${DAYTONA_SANDBOX_ENABLED:-FALSE}"
|
||||
echo "==========================================="
|
||||
echo ""
|
||||
|
||||
|
|
|
|||
|
|
@ -167,37 +167,12 @@ LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
|||
LANGSMITH_API_KEY=lsv2_pt_.....
|
||||
LANGSMITH_PROJECT=surfsense
|
||||
|
||||
# Uvicorn Server Configuration
|
||||
# Full documentation for Uvicorn options can be found at: https://www.uvicorn.org/#command-line-options
|
||||
UVICORN_HOST="0.0.0.0"
|
||||
UVICORN_PORT=8000
|
||||
UVICORN_LOG_LEVEL=info
|
||||
|
||||
# OPTIONAL: Advanced Uvicorn Options (uncomment to use)
|
||||
# UVICORN_PROXY_HEADERS=false
|
||||
# UVICORN_FORWARDED_ALLOW_IPS="127.0.0.1"
|
||||
# UVICORN_WORKERS=1
|
||||
# UVICORN_ACCESS_LOG=true
|
||||
# UVICORN_LOOP="auto"
|
||||
# UVICORN_HTTP="auto"
|
||||
# UVICORN_WS="auto"
|
||||
# UVICORN_LIFESPAN="auto"
|
||||
# UVICORN_LOG_CONFIG=""
|
||||
# UVICORN_SERVER_HEADER=true
|
||||
# UVICORN_DATE_HEADER=true
|
||||
# UVICORN_LIMIT_CONCURRENCY=
|
||||
# UVICORN_LIMIT_MAX_REQUESTS=
|
||||
# UVICORN_TIMEOUT_KEEP_ALIVE=5
|
||||
# UVICORN_TIMEOUT_NOTIFY=30
|
||||
# UVICORN_SSL_KEYFILE=""
|
||||
# UVICORN_SSL_CERTFILE=""
|
||||
# UVICORN_SSL_KEYFILE_PASSWORD=""
|
||||
# UVICORN_SSL_VERSION=""
|
||||
# UVICORN_SSL_CERT_REQS=""
|
||||
# UVICORN_SSL_CA_CERTS=""
|
||||
# UVICORN_SSL_CIPHERS=""
|
||||
# UVICORN_HEADERS=""
|
||||
# UVICORN_USE_COLORS=true
|
||||
# UVICORN_UDS=""
|
||||
# UVICORN_FD=""
|
||||
# UVICORN_ROOT_PATH=""
|
||||
# Agent Specific Configuration
|
||||
# Daytona Sandbox (secure cloud code execution for deep agent)
|
||||
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
|
||||
DAYTONA_SANDBOX_ENABLED=TRUE
|
||||
DAYTONA_API_KEY=dtn_asdasfasfafas
|
||||
DAYTONA_API_URL=https://app.daytona.io/api
|
||||
DAYTONA_TARGET=us
|
||||
# Directory for locally-persisted sandbox files (after sandbox deletion)
|
||||
SANDBOX_FILES_DIR=sandbox_files
|
||||
|
|
|
|||
1
surfsense_backend/.gitignore
vendored
|
|
@ -6,6 +6,7 @@ __pycache__/
|
|||
.flashrank_cache
|
||||
surf_new_backend.egg-info/
|
||||
podcasts/
|
||||
sandbox_files/
|
||||
temp_audio/
|
||||
celerybeat-schedule*
|
||||
celerybeat-schedule.*
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
"""102_add_enable_summary_to_connectors
|
||||
|
||||
Revision ID: 102
|
||||
Revises: 101
|
||||
Create Date: 2026-02-26
|
||||
|
||||
Adds enable_summary boolean column to search_source_connectors.
|
||||
Defaults to False for all existing and new connectors so LLM-based
|
||||
summary generation is opt-in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "102"
|
||||
down_revision: str | None = "101"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    conn = op.get_bind()
    existing_columns = [
        col["name"] for col in sa.inspect(conn).get_columns("search_source_connectors")
    ]

    if "enable_summary" not in existing_columns:
        op.add_column(
            "search_source_connectors",
            sa.Column(
                "enable_summary",
                sa.Boolean(),
                nullable=False,
                server_default=sa.text("false"),
            ),
        )


def downgrade() -> None:
    op.drop_column("search_source_connectors", "enable_summary")
|
||||
|
|
@ -6,10 +6,14 @@ with configurable tools via the tools registry and configurable prompts
|
|||
via NewLLMConfig.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import create_deep_agent
|
||||
from deepagents.backends.protocol import SandboxBackendProtocol
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
|
@ -25,6 +29,8 @@ from app.agents.new_chat.tools.registry import build_tools_async
|
|||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
|
||||
# =============================================================================
|
||||
# Connector Type Mapping
|
||||
# =============================================================================
|
||||
|
|
@ -128,6 +134,7 @@ async def create_surfsense_deep_agent(
|
|||
additional_tools: Sequence[BaseTool] | None = None,
|
||||
firecrawl_api_key: str | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
sandbox_backend: SandboxBackendProtocol | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -167,6 +174,9 @@ async def create_surfsense_deep_agent(
|
|||
These are always added regardless of enabled/disabled settings.
|
||||
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
|
||||
Falls back to Chromium/Trafilatura if not provided.
|
||||
sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
|
||||
secure code execution. When provided, the agent gets an
|
||||
isolated ``execute`` tool for running shell commands.
|
||||
|
||||
Returns:
|
||||
CompiledStateGraph: The configured deep agent
|
||||
|
|
@ -205,32 +215,41 @@ async def create_surfsense_deep_agent(
|
|||
additional_tools=[my_custom_tool]
|
||||
)
|
||||
"""
|
||||
_t_agent_total = time.perf_counter()
|
||||
|
||||
# Discover available connectors and document types for this search space
|
||||
# This enables dynamic tool docstrings that inform the LLM about what's actually available
|
||||
available_connectors: list[str] | None = None
|
||||
available_document_types: list[str] | None = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
# Get enabled search source connectors for this search space
|
||||
connector_types = await connector_service.get_available_connectors(
|
||||
search_space_id
|
||||
)
|
||||
if connector_types:
|
||||
# Convert enum values to strings and also include mapped document types
|
||||
available_connectors = _map_connectors_to_searchable_types(connector_types)
|
||||
|
||||
# Get document types that have at least one document indexed
|
||||
available_document_types = await connector_service.get_available_document_types(
|
||||
search_space_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail - fall back to all connectors if discovery fails
|
||||
import logging
|
||||
|
||||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||
_perf_log.info(
|
||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# Build dependencies dict for the tools registry
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
# Extract the model's context window so tools can size their output.
|
||||
_model_profile = getattr(llm, "profile", None)
|
||||
_max_input_tokens: int | None = (
|
||||
_model_profile.get("max_input_tokens")
|
||||
if isinstance(_model_profile, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
dependencies = {
|
||||
"search_space_id": search_space_id,
|
||||
"db_session": db_session,
|
||||
|
|
@ -241,6 +260,7 @@ async def create_surfsense_deep_agent(
|
|||
"thread_visibility": visibility,
|
||||
"available_connectors": available_connectors,
|
||||
"available_document_types": available_document_types,
|
||||
"max_input_tokens": _max_input_tokens,
|
||||
}
|
||||
|
||||
# Disable Notion action tools if no Notion connector is configured
|
||||
|
|
@ -269,35 +289,61 @@ async def create_surfsense_deep_agent(
|
|||
modified_disabled_tools.extend(linear_tools)
|
||||
|
||||
# Build tools using the async registry (includes MCP tools)
|
||||
_t0 = time.perf_counter()
|
||||
tools = await build_tools_async(
|
||||
dependencies=dependencies,
|
||||
enabled_tools=enabled_tools,
|
||||
disabled_tools=modified_disabled_tools,
|
||||
additional_tools=list(additional_tools) if additional_tools else None,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] build_tools_async in %.3fs (%d tools)",
|
||||
time.perf_counter() - _t0,
|
||||
len(tools),
|
||||
)
|
||||
|
||||
# Build system prompt based on agent_config
|
||||
_t0 = time.perf_counter()
|
||||
_sandbox_enabled = sandbox_backend is not None
|
||||
if agent_config is not None:
|
||||
# Use configurable prompt with settings from NewLLMConfig
|
||||
system_prompt = build_configurable_system_prompt(
|
||||
custom_system_instructions=agent_config.system_instructions,
|
||||
use_default_system_instructions=agent_config.use_default_system_instructions,
|
||||
citations_enabled=agent_config.citations_enabled,
|
||||
thread_visibility=thread_visibility,
|
||||
sandbox_enabled=_sandbox_enabled,
|
||||
)
|
||||
else:
|
||||
system_prompt = build_surfsense_system_prompt(
|
||||
thread_visibility=thread_visibility,
|
||||
sandbox_enabled=_sandbox_enabled,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Create the deep agent with system prompt and checkpointer
|
||||
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
|
||||
agent = create_deep_agent(
|
||||
# Build optional kwargs for the deep agent
|
||||
deep_agent_kwargs: dict[str, Any] = {}
|
||||
if sandbox_backend is not None:
|
||||
deep_agent_kwargs["backend"] = sandbox_backend
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await asyncio.to_thread(
|
||||
create_deep_agent,
|
||||
model=llm,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
context_schema=SurfSenseContextSchema,
|
||||
checkpointer=checkpointer,
|
||||
**deep_agent_kwargs,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
_perf_log.info(
|
||||
"[create_agent] Total agent creation in %.3fs",
|
||||
time.perf_counter() - _t_agent_total,
|
||||
)
|
||||
return agent
|
||||
|
|
|
|||
275
surfsense_backend/app/agents/new_chat/sandbox.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
"""
|
||||
Daytona sandbox provider for SurfSense deep agent.
|
||||
|
||||
Manages the lifecycle of sandboxed code execution environments.
|
||||
Each conversation thread gets its own isolated sandbox instance
|
||||
via the Daytona cloud API, identified by labels.
|
||||
|
||||
Files created during a session are persisted to local storage before
|
||||
the sandbox is deleted so they remain downloadable after cleanup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from daytona import (
    CreateSandboxFromSnapshotParams,
    Daytona,
    DaytonaConfig,
    SandboxState,
)
|
||||
from daytona.common.errors import DaytonaError
|
||||
from deepagents.backends.protocol import ExecuteResponse
|
||||
from langchain_daytona import DaytonaSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _TimeoutAwareSandbox(DaytonaSandbox):
    """DaytonaSandbox subclass that accepts the per-command *timeout*
    kwarg required by the deepagents middleware.

    The upstream ``langchain-daytona`` ``execute()`` ignores timeout,
    so deepagents raises *"This sandbox backend does not support
    per-command timeout overrides"* on every first call. This thin
    wrapper forwards the parameter to the Daytona SDK.
    """

    def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
        t = timeout if timeout is not None else self._timeout
        result = self._sandbox.process.exec(command, timeout=t)
        return ExecuteResponse(
            output=result.result,
            exit_code=result.exit_code,
            truncated=False,
        )

    async def aexecute(
        self, command: str, *, timeout: int | None = None
    ) -> ExecuteResponse:  # type: ignore[override]
        return await asyncio.to_thread(self.execute, command, timeout=timeout)
|
||||
|
||||
|
||||
_daytona_client: Daytona | None = None
|
||||
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
||||
THREAD_LABEL_KEY = "surfsense_thread"
|
||||
|
||||
|
||||
def is_sandbox_enabled() -> bool:
    return os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"


def _get_client() -> Daytona:
    global _daytona_client
    if _daytona_client is None:
        config = DaytonaConfig(
            api_key=os.environ.get("DAYTONA_API_KEY", ""),
            api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
            target=os.environ.get("DAYTONA_TARGET", "us"),
        )
        _daytona_client = Daytona(config)
    return _daytona_client
|
||||
|
||||
|
||||
def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
    """Find an existing sandbox for *thread_id*, or create a new one.

    If an existing sandbox is found but is stopped/archived, it will be
    restarted automatically before returning.
    """
    client = _get_client()
    labels = {THREAD_LABEL_KEY: thread_id}

    try:
        sandbox = client.find_one(labels=labels)
        logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)

        if sandbox.state in (
            SandboxState.STOPPED,
            SandboxState.STOPPING,
            SandboxState.ARCHIVED,
        ):
            logger.info("Starting stopped sandbox %s …", sandbox.id)
            sandbox.start(timeout=60)
            logger.info("Sandbox %s is now started", sandbox.id)
        elif sandbox.state in (
            SandboxState.ERROR,
            SandboxState.BUILD_FAILED,
            SandboxState.DESTROYED,
        ):
            logger.warning(
                "Sandbox %s in unrecoverable state %s — creating a new one",
                sandbox.id,
                sandbox.state,
            )
            sandbox = client.create(
                CreateSandboxFromSnapshotParams(language="python", labels=labels)
            )
            logger.info("Created replacement sandbox: %s", sandbox.id)
        elif sandbox.state != SandboxState.STARTED:
            sandbox.wait_for_sandbox_start(timeout=60)

    except Exception:
        logger.info("No existing sandbox for thread %s — creating one", thread_id)
        sandbox = client.create(
            CreateSandboxFromSnapshotParams(language="python", labels=labels)
        )
        logger.info("Created new sandbox: %s", sandbox.id)

    return _TimeoutAwareSandbox(sandbox=sandbox)
|
||||
|
||||
|
||||
async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
    """Get or create a sandbox for a conversation thread.

    Uses an in-process cache keyed by thread_id so subsequent messages
    in the same conversation reuse the sandbox object without an API call.

    Args:
        thread_id: The conversation thread identifier.

    Returns:
        DaytonaSandbox connected to the sandbox.
    """
    key = str(thread_id)
    cached = _sandbox_cache.get(key)
    if cached is not None:
        logger.info("Reusing cached sandbox for thread %s", key)
        return cached
    sandbox = await asyncio.to_thread(_find_or_create, key)
    _sandbox_cache[key] = sandbox
    return sandbox
|
||||
|
||||
|
||||
async def delete_sandbox(thread_id: int | str) -> None:
    """Delete the sandbox for a conversation thread."""
    _sandbox_cache.pop(str(thread_id), None)

    def _delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}
        try:
            sandbox = client.find_one(labels=labels)
        except DaytonaError:
            logger.debug(
                "No sandbox to delete for thread %s (already removed)", thread_id
            )
            return
        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox for thread %s",
                thread_id,
                exc_info=True,
            )

    await asyncio.to_thread(_delete)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_sandbox_files_dir() -> Path:
    return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))


def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
    """Map a sandbox-internal absolute path to a local filesystem path."""
    relative = sandbox_path.lstrip("/")
    return _get_sandbox_files_dir() / str(thread_id) / relative


def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
    """Read a previously-persisted sandbox file from local storage.

    Returns the file bytes, or *None* if the file does not exist locally.
    """
    local = _local_path_for(thread_id, sandbox_path)
    if local.is_file():
        return local.read_bytes()
    return None


def delete_local_sandbox_files(thread_id: int | str) -> None:
    """Remove all locally-persisted sandbox files for a thread."""
    thread_dir = _get_sandbox_files_dir() / str(thread_id)
    if thread_dir.is_dir():
        shutil.rmtree(thread_dir, ignore_errors=True)
        logger.info("Deleted local sandbox files for thread %s", thread_id)
|
||||
|
||||
|
||||
async def persist_and_delete_sandbox(
    thread_id: int | str,
    sandbox_file_paths: list[str],
) -> None:
    """Download sandbox files to local storage, then delete the sandbox.

    Each file in *sandbox_file_paths* is downloaded from the Daytona
    sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/…``.
    Per-file errors are logged but do **not** prevent the sandbox from
    being deleted — freeing Daytona storage is the priority.
    """
    _sandbox_cache.pop(str(thread_id), None)

    def _persist_and_delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}

        try:
            sandbox = client.find_one(labels=labels)
        except Exception:
            logger.info(
                "No sandbox found for thread %s — nothing to persist", thread_id
            )
            return

        # Ensure the sandbox is running so we can download files
        if sandbox.state != SandboxState.STARTED:
            try:
                sandbox.start(timeout=60)
            except Exception:
                logger.warning(
                    "Could not start sandbox %s for file download — deleting anyway",
                    sandbox.id,
                    exc_info=True,
                )
                with contextlib.suppress(Exception):
                    client.delete(sandbox)
                return

        for path in sandbox_file_paths:
            try:
                content: bytes = sandbox.fs.download_file(path)
                local = _local_path_for(thread_id, path)
                local.parent.mkdir(parents=True, exist_ok=True)
                local.write_bytes(content)
                logger.info("Persisted sandbox file %s → %s", path, local)
            except Exception:
                logger.warning(
                    "Failed to persist sandbox file %s for thread %s",
                    path,
                    thread_id,
                    exc_info=True,
                )

        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox %s after persistence",
                sandbox.id,
                exc_info=True,
            )

    await asyncio.to_thread(_persist_and_delete)
|
||||
|
|
@ -645,6 +645,87 @@ However, from your video learning, it's important to note that asyncio is not su
|
|||
</citation_instructions>
|
||||
"""
|
||||
|
||||
# Sandbox / code execution instructions — appended when sandbox backend is enabled.
|
||||
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
|
||||
SANDBOX_EXECUTION_INSTRUCTIONS = """
|
||||
<code_execution>
|
||||
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
|
||||
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).
|
||||
|
||||
## CRITICAL — CODE-FIRST RULE
|
||||
|
||||
ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
|
||||
- **Creating a chart, plot, graph, or visualization** → Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
|
||||
- **Data analysis, statistics, or computation** → Write code to compute the answer. Do not do math by hand in text.
|
||||
- **Generating or transforming files** (CSV, PDF, images, etc.) → Write code to create the file.
|
||||
- **Running, testing, or debugging code** → Execute it in the sandbox.
|
||||
|
||||
This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.
|
||||
|
||||
Example (CORRECT):
|
||||
User: "Create a pie chart of my benefits"
|
||||
→ 1. search_knowledge_base → retrieve benefits data
|
||||
→ 2. Immediately execute Python code (matplotlib) to generate the pie chart
|
||||
→ 3. Return the downloadable file + brief description
|
||||
|
||||
Example (WRONG):
|
||||
User: "Create a pie chart of my benefits"
|
||||
→ 1. search_knowledge_base → retrieve benefits data
|
||||
→ 2. Print a text table with percentages and ask the user if they want a chart ← NEVER do this
|
||||
|
||||
## When to Use Code Execution
|
||||
|
||||
Use the sandbox when the task benefits from actually running code rather than just describing it:
|
||||
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
|
||||
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
|
||||
- **Calculations**: Math, financial modeling, unit conversions, simulations
|
||||
- **Code validation**: Run and test code snippets the user provides or asks about
|
||||
- **File processing**: Parse, transform, or convert data files
|
||||
- **Quick prototyping**: Demonstrate working code for the user's problem
|
||||
- **Package exploration**: Install and test libraries the user is evaluating
|
||||
|
||||
## When NOT to Use Code Execution
|
||||
|
||||
Do not use the sandbox for:
|
||||
- Answering factual questions from your own knowledge
|
||||
- Summarizing or explaining concepts
|
||||
- Simple formatting or text generation tasks
|
||||
- Tasks that don't require running code to answer
|
||||
|
||||
## Package Management
|
||||
|
||||
- Use `pip install <package>` to install Python packages as needed
|
||||
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
|
||||
- Always verify a package installed successfully before using it
|
||||
|
||||
## Working Guidelines
|
||||
|
||||
- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
|
||||
- **Iterative approach**: For complex tasks, break work into steps — write code, run it, check output, refine
|
||||
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
|
||||
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
|
||||
- **Be efficient**: Install packages once per session. Combine related commands when possible.
|
||||
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.
|
||||
|
||||
## Sharing Generated Files
|
||||
|
||||
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
|
||||
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
|
||||
- **DO NOT call `display_image`** for files created inside the sandbox. Sandbox files are not accessible via public URLs, so `display_image` will always show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
|
||||
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
|
||||
- Always describe what the file contains in your response text so the user knows what they are downloading.
|
||||
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
|
||||
|
||||
## Data Analytics Best Practices
|
||||
|
||||
When the user asks you to analyze data:
|
||||
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
|
||||
2. Clean and validate before computing (handle nulls, check types)
|
||||
3. Perform the analysis and present results clearly
|
||||
4. Offer follow-up insights or visualizations when appropriate
|
||||
</code_execution>
|
||||
"""
|
||||
|
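The prompt above asks the model to announce generated files with `SANDBOX_FILE: <path>` lines, one per file. How the frontend actually extracts those markers is not shown in this diff; the sketch below is one possible way to collect them from command output. The helper name `extract_sandbox_files` and the de-duplication behaviour are assumptions.

```python
import re

# One marker per line, e.g. "SANDBOX_FILE: /tmp/chart.png".
# [ \t]* is used instead of \s* so the match never spills onto the next line.
_SANDBOX_FILE_RE = re.compile(r"^SANDBOX_FILE:[ \t]*(\S.*?)[ \t]*$", re.MULTILINE)


def extract_sandbox_files(stdout: str) -> list[str]:
    """Return announced paths in first-seen order, without duplicates."""
    seen: set[str] = set()
    paths: list[str] = []
    for match in _SANDBOX_FILE_RE.finditer(stdout):
        path = match.group(1)
        if path not in seen:
            seen.add(path)
            paths.append(path)
    return paths
```

Keeping the marker on its own line is what makes this kind of line-anchored matching workable, which is presumably why the prompt insists that every file-saving `execute` call prints it.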
||||
# Anti-citation prompt - used when citations are disabled
|
||||
# This explicitly tells the model NOT to include citations
|
||||
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
|
||||
|
|
@ -670,6 +751,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
|||
def build_surfsense_system_prompt(
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
sandbox_enabled: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Build the SurfSense system prompt with default settings.
|
||||
|
|
@ -678,10 +760,12 @@ def build_surfsense_system_prompt(
|
|||
- Default system instructions
|
||||
- Tools instructions (always included)
|
||||
- Citation instructions enabled
|
||||
- Sandbox execution instructions (when sandbox_enabled=True)
|
||||
|
||||
Args:
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, selects the visibility-specific prompt wording (e.g. private vs shared memory). Defaults to private behavior when None.
|
||||
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -691,7 +775,13 @@ def build_surfsense_system_prompt(
|
|||
system_instructions = _get_system_instructions(visibility, today)
|
||||
tools_instructions = _get_tools_instructions(visibility)
|
||||
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
|
||||
return system_instructions + tools_instructions + citation_instructions
|
||||
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
|
||||
return (
|
||||
system_instructions
|
||||
+ tools_instructions
|
||||
+ citation_instructions
|
||||
+ sandbox_instructions
|
||||
)
|
||||
|
||||
|
||||
def build_configurable_system_prompt(
|
||||
|
|
@ -700,14 +790,16 @@ def build_configurable_system_prompt(
|
|||
citations_enabled: bool = True,
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
sandbox_enabled: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||
|
||||
The prompt is composed of three parts:
|
||||
The prompt is composed of up to four parts:
|
||||
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
2. Tools Instructions - always included (SURFSENSE_TOOLS_INSTRUCTIONS)
|
||||
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
|
||||
4. Sandbox Execution Instructions - when sandbox_enabled=True
|
||||
|
||||
Args:
|
||||
custom_system_instructions: Custom system instructions to use. If empty/None and
|
||||
|
|
@ -719,6 +811,7 @@ def build_configurable_system_prompt(
|
|||
anti-citation instructions (False).
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, selects the visibility-specific prompt wording (e.g. private vs shared memory). Defaults to private behavior when None.
|
||||
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -727,7 +820,6 @@ def build_configurable_system_prompt(
|
|||
|
||||
# Determine system instructions
|
||||
if custom_system_instructions and custom_system_instructions.strip():
|
||||
# Use custom instructions, injecting the date placeholder if present
|
||||
system_instructions = custom_system_instructions.format(
|
||||
resolved_today=resolved_today
|
||||
)
|
||||
|
|
@ -735,7 +827,6 @@ def build_configurable_system_prompt(
|
|||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
else:
|
||||
# No system instructions (edge case)
|
||||
system_instructions = ""
|
||||
|
||||
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
|
||||
|
|
@ -748,7 +839,14 @@ def build_configurable_system_prompt(
|
|||
else SURFSENSE_NO_CITATION_INSTRUCTIONS
|
||||
)
|
||||
|
||||
return system_instructions + tools_instructions + citation_instructions
|
||||
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
|
||||
|
||||
return (
|
||||
system_instructions
|
||||
+ tools_instructions
|
||||
+ citation_instructions
|
||||
+ sandbox_instructions
|
||||
)
|
||||
|
||||
|
||||
def get_default_system_instructions() -> str:
|
||||
|
|
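Both prompt builders in this file now end by concatenating up to four sections, with the sandbox section contributing an empty string when `sandbox_enabled` is false. A minimal sketch of that composition follows; the section constants are short placeholders and `compose_prompt` is a hypothetical name, not a function from the repository.

```python
# Placeholder sections; the real constants live in the module shown in the diff.
SYSTEM = "<system_instruction>...</system_instruction>\n"
TOOLS = "<tools>...</tools>\n"
CITATIONS = "<citation_instructions>...</citation_instructions>\n"
SANDBOX = "<code_execution>...</code_execution>\n"


def compose_prompt(*, sandbox_enabled: bool = False) -> str:
    """Concatenate the fixed sections, appending the sandbox block only on demand."""
    sandbox_part = SANDBOX if sandbox_enabled else ""
    return SYSTEM + TOOLS + CITATIONS + sandbox_part
```

Because the flag defaults to `False`, existing callers that do not pass it receive the same prompt as before the change.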
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
from app.agents.new_chat.tools.google_drive.create_file import (
    create_create_google_drive_file_tool,
)
from app.agents.new_chat.tools.google_drive.trash_file import (
    create_delete_google_drive_file_tool,
)

__all__ = [
    "create_create_google_drive_file_tool",
    "create_delete_google_drive_file_tool",
]
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
import logging
from typing import Any, Literal

from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)

_MIME_MAP: dict[str, str] = {
    "google_doc": GOOGLE_DOC,
    "google_sheet": GOOGLE_SHEET,
}
|
||||
|
||||
|
||||
def create_create_google_drive_file_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_google_drive_file(
|
||||
name: str,
|
||||
file_type: Literal["google_doc", "google_sheet"],
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new Google Doc or Google Sheet in Google Drive.
|
||||
|
||||
Use this tool when the user explicitly asks to create a new document
|
||||
or spreadsheet in Google Drive.
|
||||
|
||||
Args:
|
||||
name: The file name (without extension).
|
||||
file_type: Either "google_doc" or "google_sheet".
|
||||
content: Optional initial content. For google_doc, provide markdown text.
|
||||
For google_sheet, provide CSV-formatted text.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "insufficient_permissions", or "error"
|
||||
- file_id: Google Drive file ID (if success)
|
||||
- name: File name (if success)
|
||||
- web_view_link: URL to open the file (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Create a Google Doc called 'Meeting Notes'"
|
||||
- "Create a spreadsheet named 'Budget 2026' with some sample data"
|
||||
"""
|
||||
logger.info(
|
||||
f"create_google_drive_file called: name='{name}', type='{file_type}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Drive tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
if file_type not in _MIME_MAP:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unsupported file type '{file_type}'. Use 'google_doc' or 'google_sheet'.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_drive_file_creation",
|
||||
"action": {
|
||||
"tool": "create_google_drive_file",
|
||||
"params": {
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_name = final_params.get("name", name)
|
||||
final_file_type = final_params.get("file_type", file_type)
|
||||
final_content = final_params.get("content", content)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_parent_folder_id = final_params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
||||
mime_type = _MIME_MAP.get(final_file_type)
|
||||
if not mime_type:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unsupported file type '{final_file_type}'.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if final_connector_id is not None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Drive connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
|
||||
logger.info(
|
||||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
client = GoogleDriveClient(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
mime_type=mime_type,
|
||||
parent_folder_id=final_parent_folder_id,
|
||||
content=final_content,
|
||||
)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_view_link": created.get("webViewLink"),
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating Google Drive file: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the file. Please try again.",
|
||||
}
|
||||
|
||||
return create_google_drive_file
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
import logging
from typing import Any

from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from langgraph.types import interrupt
from sqlalchemy.ext.asyncio import AsyncSession

from app.connectors.google_drive.client import GoogleDriveClient
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_google_drive_file_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def delete_google_drive_file(
|
||||
file_name: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Move a Google Drive file to trash.
|
||||
|
||||
Use this tool when the user explicitly asks to delete, remove, or trash
|
||||
a file in Google Drive.
|
||||
|
||||
Args:
|
||||
file_name: The exact name of the file to trash (as it appears in Drive).
|
||||
delete_from_kb: Whether to also remove the file from the knowledge base.
|
||||
Default is False.
|
||||
Set to True to remove from both Google Drive and knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", "insufficient_permissions", or "error"
|
||||
- file_id: Google Drive file ID (if success)
|
||||
- deleted_from_kb: whether the document was removed from the knowledge base
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Respond with a brief
|
||||
acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the file name or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry this tool.
|
||||
|
||||
Examples:
|
||||
- "Delete the 'Meeting Notes' file from Google Drive"
|
||||
- "Trash the 'Old Budget' spreadsheet"
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Google Drive tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GoogleDriveToolMetadataService(db_session)
|
||||
context = await metadata_service.get_trash_context(
|
||||
search_space_id, user_id, file_name
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"File not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch trash context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
file = context["file"]
|
||||
file_id = file["file_id"]
|
||||
document_id = file.get("document_id")
|
||||
connector_id_from_context = context["account"]["id"]
|
||||
|
||||
if not file_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "File ID is missing from the indexed document. Please re-index the file and try again.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_drive_file_trash",
|
||||
"action": {
|
||||
"tool": "delete_google_drive_file",
|
||||
"params": {
|
||||
"file_id": file_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_file_id = final_params.get("file_id", file_id)
|
||||
final_connector_id = final_params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this file.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Google Drive connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||
)
|
||||
client = GoogleDriveClient(session=db_session, connector_id=connector.id)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
|
||||
)
|
||||
|
||||
trash_result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"file_id": final_file_id,
|
||||
"message": f"Successfully moved '{file['name']}' to trash.",
|
||||
}
|
||||
|
||||
deleted_from_kb = False
|
||||
if final_delete_from_kb and document_id:
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
trash_result["warning"] = (
|
||||
f"File moved to trash, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
trash_result["deleted_from_kb"] = deleted_from_kb
|
||||
if deleted_from_kb:
|
||||
trash_result["message"] = (
|
||||
f"{trash_result.get('message', '')} (also removed from knowledge base)"
|
||||
)
|
||||
|
||||
return trash_result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error deleting Google Drive file: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while trashing the file. Please try again.",
|
||||
}
|
||||
|
||||
return delete_google_drive_file
|
||||
|
|
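The trash tool performs two steps in order: it moves the file to the Drive trash, then optionally removes the matching document from the knowledge base. A failure in the second step is reported as a `warning` on an otherwise successful result. The sketch below shows that ordering with stand-in callables; `trash_then_cleanup`, `trash`, and `remove_from_kb` are hypothetical names, and the real tool talks to `GoogleDriveClient` and the database session instead.

```python
import asyncio
from typing import Any, Awaitable, Callable


async def trash_then_cleanup(
    trash: Callable[[], Awaitable[None]],
    remove_from_kb: Callable[[], Awaitable[bool]],
    *,
    delete_from_kb: bool,
) -> dict[str, Any]:
    """Trash the remote file first, then optionally remove the local record.

    A failure in the second step becomes a warning; the first step is not undone.
    """
    await trash()  # errors here propagate before any local state is touched
    result: dict[str, Any] = {"status": "success", "deleted_from_kb": False}
    if delete_from_kb:
        try:
            result["deleted_from_kb"] = await remove_from_kb()
        except Exception as exc:  # broad on purpose, mirroring the tool
            result["warning"] = f"File moved to trash, but KB cleanup failed: {exc}"
    return result
```

Doing the remote operation first means a knowledge-base record is never removed for a file that is still live in Drive; the cost is that a cleanup failure leaves a stale record, which the warning surfaces.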
@ -172,12 +172,52 @@ def _normalize_connectors(
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
|
||||
# Fraction of the model's context window that a single tool result may
# occupy. The remainder is reserved for the system prompt, conversation
# history, and model output. The budget is tracked in characters; at
# ~4 chars/token a tool result gets roughly 25% of the token budget.
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
_CHARS_PER_TOKEN = 4

# Hard floor / ceiling so the budget is always sensible regardless of what
# the model reports.
_MIN_TOOL_OUTPUT_CHARS = 20_000  # ~5K tokens
_MAX_TOOL_OUTPUT_CHARS = 400_000  # ~100K tokens
_MAX_CHUNK_CHARS = 8_000


def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
    """Derive a character budget from the model's context window.

    ``max_input_tokens`` is the value that ``ChatLiteLLMRouter`` /
    ``ChatLiteLLM`` already resolved (via ``litellm.get_model_info``) and
    passed through the dependency chain. Falls back to the conservative
    ``_MIN_TOOL_OUTPUT_CHARS`` floor when the value is unavailable.
    """
    if max_input_tokens is None or max_input_tokens <= 0:
        return _MIN_TOOL_OUTPUT_CHARS  # conservative fallback

    budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
    return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
|
||||
|
||||
|
||||
def format_documents_for_context(
|
||||
documents: list[dict[str, Any]],
|
||||
*,
|
||||
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
|
||||
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
||||
) -> str:
|
||||
"""
|
||||
Format retrieved documents into a readable context string for the LLM.
|
||||
|
||||
Documents are added in order (highest relevance first) until the character
|
||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
|
||||
a single oversized chunk cannot monopolize the output.
|
||||
|
||||
Args:
|
||||
documents: List of document dictionaries from connector search
|
||||
max_chars: Approximate character budget for the entire output.
|
||||
max_chunk_chars: Per-chunk character cap; longer content is cut at the cap and marked as truncated.
|
||||
|
||||
Returns:
|
||||
Formatted string with document contents and metadata
|
||||
|
|
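The budget arithmetic in `_compute_tool_output_budget` can be checked with a few sample context sizes. The following is a minimal sketch that copies the constants from the hunk above; the function name `compute_budget` and the sample token counts are illustrative assumptions, not part of the change.

```python
from typing import Optional

# Constants copied from the hunk above.
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
_CHARS_PER_TOKEN = 4
_MIN_TOOL_OUTPUT_CHARS = 20_000
_MAX_TOOL_OUTPUT_CHARS = 400_000


def compute_budget(max_input_tokens: Optional[int]) -> int:
    """Same arithmetic as _compute_tool_output_budget (hypothetical stand-in)."""
    if max_input_tokens is None or max_input_tokens <= 0:
        return _MIN_TOOL_OUTPUT_CHARS
    budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
    # Clamp between the hard floor and the hard ceiling.
    return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))


# Sample context windows (assumed values, for illustration only).
for tokens in (None, 8_192, 128_000, 1_000_000):
    print(tokens, "->", compute_budget(tokens))
```

With these constants the floor applies to any model reporting fewer than 20,000 tokens, and the ceiling applies from 400,000 tokens upward.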
@ -278,37 +318,57 @@ def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
|
|||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
||||
# Render XML expected by citation instructions
|
||||
# Render XML expected by citation instructions, respecting the char budget.
|
||||
parts: list[str] = []
|
||||
for g in grouped.values():
|
||||
total_chars = 0
|
||||
total_docs = len(grouped)
|
||||
|
||||
for doc_idx, g in enumerate(grouped.values()):
|
||||
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
|
||||
is_live_search = g["document_type"] in live_search_connectors
|
||||
|
||||
parts.append("<document>")
|
||||
parts.append("<document_metadata>")
|
||||
parts.append(f" <document_id>{g['document_id']}</document_id>")
|
||||
parts.append(f" <document_type>{g['document_type']}</document_type>")
|
||||
parts.append(f" <title><![CDATA[{g['title']}]]></title>")
|
||||
parts.append(f" <url><![CDATA[{g['url']}]]></url>")
|
||||
parts.append(f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
|
||||
parts.append("</document_metadata>")
|
||||
parts.append("")
|
||||
parts.append("<document_content>")
|
||||
doc_lines: list[str] = [
|
||||
"<document>",
|
||||
"<document_metadata>",
|
||||
f" <document_id>{g['document_id']}</document_id>",
|
||||
f" <document_type>{g['document_type']}</document_type>",
|
||||
f" <title><![CDATA[{g['title']}]]></title>",
|
||||
f" <url><![CDATA[{g['url']}]]></url>",
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||
"</document_metadata>",
|
||||
"",
|
||||
"<document_content>",
|
||||
]
|
||||
|
||||
for ch in g["chunks"]:
|
||||
ch_content = ch["content"]
|
||||
if max_chunk_chars and len(ch_content) > max_chunk_chars:
|
||||
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
|
||||
# For live search connectors, use the document URL as the chunk id
|
||||
# so the LLM outputs [citation:https://...] which the frontend
|
||||
# renders as a clickable link.
|
||||
ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
|
||||
if ch_id is None:
|
||||
parts.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
|
||||
doc_lines.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
|
||||
else:
|
||||
parts.append(f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>")
|
||||
doc_lines.append(
|
||||
f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
|
||||
)
|
||||
|
||||
parts.append("</document_content>")
|
||||
parts.append("</document>")
|
||||
parts.append("")
|
||||
doc_lines.extend(["</document_content>", "</document>", ""])
|
||||
|
||||
doc_xml = "\n".join(doc_lines)
|
||||
doc_len = len(doc_xml)
|
||||
|
||||
# Always include at least the first document; afterwards enforce budget.
|
||||
if doc_idx > 0 and total_chars + doc_len > max_chars:
|
||||
remaining = total_docs - doc_idx
|
||||
parts.append(
|
||||
f"<!-- Output truncated: {remaining} more document(s) omitted "
|
||||
f"(budget {max_chars} chars). Refine your query or reduce top_k "
|
||||
f"to retrieve different results. -->"
|
||||
)
|
||||
break
|
||||
|
||||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
|
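The budget check in the loop above has two properties worth noting: the first document is always emitted even if it alone exceeds `max_chars`, and the truncation notice itself is not counted against the budget. A minimal sketch of the same control flow, using plain strings in place of rendered document XML (the helper name `join_within_budget` is hypothetical):

```python
def join_within_budget(docs: list[str], max_chars: int) -> str:
    """Append documents in order until the character budget would be exceeded."""
    parts: list[str] = []
    total_chars = 0
    for idx, doc in enumerate(docs):
        # Always keep the first document; enforce the budget afterwards.
        if idx > 0 and total_chars + len(doc) > max_chars:
            remaining = len(docs) - idx
            parts.append(
                f"<!-- Output truncated: {remaining} more document(s) omitted -->"
            )
            break
        parts.append(doc)
        total_chars += len(doc)
    return "\n".join(parts).strip()


print(join_within_budget(["a" * 50, "b" * 30, "c" * 30], max_chars=90))
```

Because the join separators and the notice are outside the count, the final string can be slightly longer than `max_chars`, which matches the "approximate character budget" wording in the docstring.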
@ -328,6 +388,7 @@ async def search_knowledge_base_async(
|
|||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search the user's knowledge base for relevant documents.
|
||||
|
|
@ -345,6 +406,8 @@ async def search_knowledge_base_async(
|
|||
end_date: Optional end datetime (UTC) for filtering documents
|
||||
available_connectors: Optional list of connectors actually available in the search space.
|
||||
If provided, only these connectors will be searched.
|
||||
max_input_tokens: Model context window size (tokens). Used to dynamically
|
||||
size the output so it fits within the model's limits.
|
||||
|
||||
Returns:
|
||||
Formatted string with search results
|
||||
|
|
@ -488,7 +551,8 @@ async def search_knowledge_base_async(
|
|||
|
||||
deduplicated.append(doc)
|
||||
|
||||
return format_documents_for_context(deduplicated)
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
return format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
|
||||
|
||||
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
|
||||
|
|
@ -552,6 +616,7 @@ def create_search_knowledge_base_tool(
|
|||
connector_service: ConnectorService,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> StructuredTool:
|
||||
"""
|
||||
Factory function to create the search_knowledge_base tool with injected dependencies.
|
||||
|
|
@ -564,6 +629,8 @@ def create_search_knowledge_base_tool(
|
|||
Used to dynamically generate the tool docstring.
|
||||
available_document_types: Optional list of document types that have data in the search space.
|
||||
Used to inform the LLM about what data exists.
|
||||
max_input_tokens: Model context window (tokens) from litellm model info.
|
||||
Used to dynamically size tool output.
|
||||
|
||||
Returns:
|
||||
A configured StructuredTool instance
|
||||
|
|
@ -634,6 +701,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
|||
start_date=parsed_start,
|
||||
end_date=parsed_end,
|
||||
available_connectors=_available_connectors,
|
||||
max_input_tokens=max_input_tokens,
|
||||
)
|
||||
|
||||
# Create StructuredTool with dynamic description
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ This implements real MCP protocol support similar to Cursor's implementation.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
|
@ -25,6 +26,9 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||
|
||||
|
||||
def _create_dynamic_input_model_from_schema(
|
||||
tool_name: str,
|
||||
|
|
@ -355,6 +359,19 @@ async def _load_http_mcp_tools(
|
|||
return tools
|
||||
|
||||
|
||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||
"""Invalidate cached MCP tools.
|
||||
|
||||
Args:
|
||||
search_space_id: If provided, only invalidate for this search space.
|
||||
If None, invalidate all cached MCP tools.
|
||||
"""
|
||||
if search_space_id is not None:
|
||||
_mcp_tools_cache.pop(search_space_id, None)
|
||||
else:
|
||||
_mcp_tools_cache.clear()
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
@ -364,6 +381,9 @@ async def load_mcp_tools(
|
|||
This discovers tools dynamically from MCP servers using the protocol.
|
||||
Supports both stdio (local process) and HTTP (remote server) transports.
|
||||
|
||||
Results are cached per search space for up to 5 minutes to avoid
|
||||
re-spawning MCP server processes on every chat message.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
search_space_id: User's search space ID
|
||||
|
|
@ -372,8 +392,20 @@ async def load_mcp_tools(
|
|||
List of LangChain StructuredTool instances
|
||||
|
||||
"""
|
||||
now = time.monotonic()
|
||||
cached = _mcp_tools_cache.get(search_space_id)
|
||||
if cached is not None:
|
||||
cached_at, cached_tools = cached
|
||||
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
|
||||
logger.info(
|
||||
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
|
||||
search_space_id,
|
||||
len(cached_tools),
|
||||
now - cached_at,
|
||||
)
|
||||
return list(cached_tools)
|
||||
|
||||
try:
|
||||
# Fetch all MCP connectors for this search space
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.connector_type
|
||||
|
|
@ -385,27 +417,22 @@ async def load_mcp_tools(
|
|||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
try:
|
||||
# Early validation: Extract and validate connector config
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
|
||||
# Validate server_config exists and is a dict
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Determine transport type
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
# HTTP-based MCP server
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
else:
|
||||
# stdio-based MCP server (default)
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
|
|
@ -417,6 +444,7 @@ async def load_mcp_tools(
|
|||
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
|
||||
)
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||
return tools
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,10 @@ from app.db import ChatVisibility
|
|||
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .linear import (
|
||||
create_create_linear_issue_tool,
|
||||
|
|
@ -114,6 +118,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
# Optional: dynamically discovered connectors/document types
|
||||
available_connectors=deps.get("available_connectors"),
|
||||
available_document_types=deps.get("available_document_types"),
|
||||
max_input_tokens=deps.get("max_input_tokens"),
|
||||
),
|
||||
requires=["search_space_id", "db_session", "connector_service"],
|
||||
# Note: available_connectors and available_document_types are optional
|
||||
|
|
@ -292,6 +297,29 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_google_drive_file",
|
||||
description="Create a new Google Doc or Google Sheet in Google Drive",
|
||||
factory=lambda deps: create_create_google_drive_file_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_google_drive_file",
|
||||
description="Move an indexed Google Drive file to trash",
|
||||
factory=lambda deps: create_delete_google_drive_file_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -417,8 +445,18 @@ async def build_tools_async(
|
|||
List of configured tool instances ready for the agent, including MCP tools.
|
||||
|
||||
"""
|
||||
# Build standard tools
|
||||
import time
|
||||
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
_perf_log.setLevel(logging.DEBUG)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
|
||||
_perf_log.info(
|
||||
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
|
||||
time.perf_counter() - _t0,
|
||||
len(tools),
|
||||
)
|
||||
|
||||
# Load MCP tools if requested and dependencies are available
|
||||
if (
|
||||
|
|
@ -427,10 +465,16 @@ async def build_tools_async(
|
|||
and "search_space_id" in dependencies
|
||||
):
|
||||
try:
|
||||
_t0 = time.perf_counter()
|
||||
mcp_tools = await load_mcp_tools(
|
||||
dependencies["db_session"],
|
||||
dependencies["search_space_id"],
|
||||
)
|
||||
_perf_log.info(
|
||||
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
|
||||
time.perf_counter() - _t0,
|
||||
len(mcp_tools),
|
||||
)
|
||||
tools.extend(mcp_tools)
|
||||
logging.info(
|
||||
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
|
||||
def format_surfsense_docs_results(results: list[tuple]) -> str:
|
||||
|
|
@ -100,7 +100,7 @@ async def search_surfsense_docs_async(
|
|||
Formatted string with relevant documentation content
|
||||
"""
|
||||
# Get embedding for the query
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
query_embedding = embed_text(query)
|
||||
|
||||
# Vector similarity search on chunks, joining with documents
|
||||
stmt = (
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import MemoryCategory, SharedMemory, User
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ async def save_shared_memory(
|
|||
count = await get_shared_memory_count(db_session, search_space_id)
|
||||
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
||||
await delete_oldest_shared_memory(db_session, search_space_id)
|
||||
embedding = config.embedding_model_instance.embed(content)
|
||||
embedding = embed_text(content)
|
||||
row = SharedMemory(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=_to_uuid(created_by_id),
|
||||
|
|
@ -108,7 +108,7 @@ async def recall_shared_memory(
|
|||
if category and category in valid_categories:
|
||||
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
||||
if query:
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
query_embedding = embed_text(query)
|
||||
stmt = stmt.order_by(
|
||||
SharedMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ from langchain_core.tools import tool
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import MemoryCategory, UserMemory
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -178,7 +178,7 @@ def create_save_memory_tool(
|
|||
await delete_oldest_memory(db_session, user_id, search_space_id)
|
||||
|
||||
# Generate embedding for the memory
|
||||
embedding = config.embedding_model_instance.embed(content)
|
||||
embedding = embed_text(content)
|
||||
|
||||
# Create new memory using ORM
|
||||
# The pgvector Vector column type handles embedding conversion automatically
|
||||
|
|
@ -268,7 +268,7 @@ def create_recall_memory_tool(
|
|||
|
||||
if query:
|
||||
# Semantic search using embeddings
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
query_embedding = embed_text(query)
|
||||
|
||||
# Build query with vector similarity
|
||||
stmt = (
|
||||
|
|
|
|||
|
|
@ -175,8 +175,39 @@ def rate_limit_password_reset(request: Request):
|
|||
)
|
||||
|
||||
|
||||
def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
|
||||
"""Monkey-patch the event loop to warn whenever a callback blocks longer than *threshold_sec*.
|
||||
|
||||
This helps pinpoint synchronous code that freezes the entire FastAPI server.
|
||||
Only active when the PERF_DEBUG env var is set (to avoid overhead in production).
|
||||
"""
|
||||
import os
|
||||
|
||||
if not os.environ.get("PERF_DEBUG"):
|
||||
return
|
||||
|
||||
_slow_log = logging.getLogger("surfsense.perf.slow")
|
||||
_slow_log.setLevel(logging.WARNING)
|
||||
if not _slow_log.handlers:
|
||||
_h = logging.StreamHandler()
|
||||
_h.setFormatter(logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s"))
|
||||
_slow_log.addHandler(_h)
|
||||
_slow_log.propagate = False
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.slow_callback_duration = threshold_sec # type: ignore[attr-defined]
|
||||
loop.set_debug(True)
|
||||
_slow_log.warning(
|
||||
"Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
|
||||
"Set PERF_DEBUG='' to disable.",
|
||||
threshold_sec,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
|
||||
_enable_slow_callback_logging(threshold_sec=0.5)
|
||||
# Not needed if you setup a migration system like Alembic
|
||||
await create_db_and_tables()
|
||||
# Setup LangGraph checkpointer tables for conversation persistence
|
||||
|
|
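The detector above relies on asyncio's built-in debug mode: when the loop is in debug mode, any callback that runs longer than `loop.slow_callback_duration` is reported as a warning on the `asyncio` logger. A minimal standalone sketch of that mechanism follows; the helper names, the list-collecting handler, and the short threshold are assumptions made for illustration.

```python
import asyncio
import logging
import time


class _ListHandler(logging.Handler):
    """Collects formatted log messages so they can be inspected afterwards."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


async def _blocking_demo(threshold_sec: float) -> None:
    loop = asyncio.get_running_loop()
    loop.slow_callback_duration = threshold_sec
    loop.set_debug(True)
    await asyncio.sleep(0)          # yield so the next task step is timed
    time.sleep(threshold_sec * 4)   # synchronous call that blocks the loop
    await asyncio.sleep(0)


def run_demo(threshold_sec: float = 0.05) -> list:
    """Run the demo and return what asyncio logged about slow callbacks."""
    handler = _ListHandler()
    asyncio_log = logging.getLogger("asyncio")
    asyncio_log.addHandler(handler)
    try:
        asyncio.run(_blocking_demo(threshold_sec))
    finally:
        asyncio_log.removeHandler(handler)
    return handler.messages


for message in run_demo():
    print(message)
```

Debug mode adds overhead to every callback, which is consistent with the diff gating it behind the `PERF_DEBUG` environment variable.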
|
|||
|
|
@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
|
|
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
|
|||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -383,6 +383,7 @@ async def _process_gmail_messages_phase2(
|
|||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
|
|
@ -415,7 +416,7 @@ async def _process_gmail_messages_phase2(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
|
|
@ -427,10 +428,8 @@ async def _process_gmail_messages_phase2(
|
|||
item["markdown_content"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Gmail: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
|
|
@ -646,6 +645,7 @@ async def index_composio_gmail(
|
|||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.composio_connector import ComposioConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.composio_service import TOOLKIT_TO_DOCUMENT_TYPE
|
||||
|
|
@ -27,6 +26,7 @@ from app.tasks.connector_indexers.base import (
|
|||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -440,7 +440,7 @@ async def index_composio_google_calendar(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"summary": item["summary"],
|
||||
|
|
@ -456,12 +456,10 @@ async def index_composio_google_calendar(
|
|||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
summary_content = f"Calendar: {item['summary']}\n\nStart: {item['start_time']}\nEnd: {item['end_time']}"
|
||||
if item["location"]:
|
||||
summary_content += f"\nLocation: {item['location']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
summary_content = (
|
||||
f"Calendar: {item['summary']}\n\n{item['markdown_content']}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from app.tasks.connector_indexers.base import (
|
|||
)
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -714,6 +715,7 @@ async def index_composio_google_drive(
|
|||
max_items=max_items,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
else:
|
||||
|
|
@ -747,6 +749,7 @@ async def index_composio_google_drive(
|
|||
max_items=max_items,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
enable_summary=getattr(connector, "enable_summary", False),
|
||||
on_heartbeat_callback=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
|
|
@ -829,6 +832,7 @@ async def _index_composio_drive_delta_sync(
|
|||
max_items: int,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index Google Drive files using delta sync with real-time document status updates.
|
||||
|
|
@ -1079,7 +1083,7 @@ async def _index_composio_drive_delta_sync(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"file_id": item["file_id"],
|
||||
"file_name": item["file_name"],
|
||||
|
|
@ -1090,10 +1094,8 @@ async def _index_composio_drive_delta_sync(
|
|||
markdown_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
|
||||
|
|
@ -1155,6 +1157,7 @@ async def _index_composio_drive_full_scan(
|
|||
max_items: int,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry,
|
||||
enable_summary: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index Google Drive files using full scan with real-time document status updates."""
|
||||
|
|
@ -1488,7 +1491,7 @@ async def _index_composio_drive_full_scan(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"file_id": item["file_id"],
|
||||
"file_name": item["file_name"],
|
||||
|
|
@ -1499,10 +1502,8 @@ async def _index_composio_drive_full_scan(
|
|||
markdown_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Google Drive File: {item['file_name']}\n\nType: {item['mime_type']}\n\n{markdown_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
"""Google Drive API client."""
|
||||
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseUpload
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .credentials import get_valid_credentials
|
||||
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
|
||||
|
||||
|
||||
class GoogleDriveClient:
|
||||
|
|
@ -179,3 +182,65 @@ class GoogleDriveClient:
|
|||
return None, f"HTTP error exporting file: {e.resp.status}"
|
||||
except Exception as e:
|
||||
return None, f"Error exporting file: {e!s}"
|
||||
|
||||
async def create_file(
|
||||
self,
|
||||
name: str,
|
||||
mime_type: str,
|
||||
parent_folder_id: str | None = None,
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
service = await self.get_service()
|
||||
|
||||
body: dict[str, Any] = {"name": name, "mimeType": mime_type}
|
||||
if parent_folder_id:
|
||||
body["parents"] = [parent_folder_id]
|
||||
|
||||
media: MediaIoBaseUpload | None = None
|
||||
if content:
|
||||
if mime_type == GOOGLE_DOC:
|
||||
import markdown as md_lib
|
||||
|
||||
html = md_lib.markdown(content)
|
||||
media = MediaIoBaseUpload(
|
||||
io.BytesIO(html.encode("utf-8")),
|
||||
mimetype="text/html",
|
||||
resumable=False,
|
||||
)
|
||||
elif mime_type == GOOGLE_SHEET:
|
||||
media = MediaIoBaseUpload(
|
||||
io.BytesIO(content.encode("utf-8")),
|
||||
mimetype="text/csv",
|
||||
resumable=False,
|
||||
)
|
||||
|
||||
if media:
|
||||
return (
|
||||
service.files()
|
||||
.create(
|
||||
body=body,
|
||||
media_body=media,
|
||||
fields="id,name,mimeType,webViewLink",
|
||||
supportsAllDrives=True,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return (
|
||||
service.files()
|
||||
.create(
|
||||
body=body,
|
||||
fields="id,name,mimeType,webViewLink",
|
||||
supportsAllDrives=True,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
async def trash_file(self, file_id: str) -> bool:
|
||||
service = await self.get_service()
|
||||
service.files().update(
|
||||
fileId=file_id,
|
||||
body={"trashed": True},
|
||||
supportsAllDrives=True,
|
||||
).execute()
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||||
|
|
@ -31,7 +31,7 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
DATABASE_URL = config.DATABASE_URL
|
||||
|
||||
|
||||
class DocumentType(str, Enum):
|
||||
class DocumentType(StrEnum):
|
||||
EXTENSION = "EXTENSION"
|
||||
CRAWLED_URL = "CRAWLED_URL"
|
||||
FILE = "FILE"
|
||||
|
|
@ -60,7 +60,7 @@ class DocumentType(str, Enum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
class SearchSourceConnectorType(str, Enum):
|
||||
class SearchSourceConnectorType(StrEnum):
|
||||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||||
TAVILY_API = "TAVILY_API"
|
||||
SEARXNG_API = "SEARXNG_API"
|
||||
|
|
@ -93,7 +93,7 @@ class SearchSourceConnectorType(str, Enum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
class PodcastStatus(str, Enum):
|
||||
class PodcastStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
|
|
@ -177,7 +177,7 @@ class DocumentStatus:
|
|||
return None
|
||||
|
||||
|
||||
class LiteLLMProvider(str, Enum):
|
||||
class LiteLLMProvider(StrEnum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
"""
|
||||
|
|
@ -215,7 +215,7 @@ class LiteLLMProvider(str, Enum):
|
|||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ImageGenProvider(str, Enum):
|
||||
class ImageGenProvider(StrEnum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
|
|
@ -233,7 +233,7 @@ class ImageGenProvider(str, Enum):
|
|||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class LogLevel(str, Enum):
|
||||
class LogLevel(StrEnum):
|
||||
DEBUG = "DEBUG"
|
||||
INFO = "INFO"
|
||||
WARNING = "WARNING"
|
||||
|
|
@ -241,13 +241,13 @@ class LogLevel(str, Enum):
|
|||
CRITICAL = "CRITICAL"
|
||||
|
||||
|
||||
class LogStatus(str, Enum):
|
||||
class LogStatus(StrEnum):
|
||||
IN_PROGRESS = "IN_PROGRESS"
|
||||
SUCCESS = "SUCCESS"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
class IncentiveTaskType(str, Enum):
|
||||
class IncentiveTaskType(StrEnum):
|
||||
"""
|
||||
Enum for incentive task types that users can complete to earn free pages.
|
||||
Each task can only be completed once per user.
|
||||
|
|
@ -298,7 +298,7 @@ INCENTIVE_TASKS_CONFIG = {
|
|||
}
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
class Permission(StrEnum):
|
||||
"""
|
||||
Granular permissions for search space resources.
|
||||
Use '*' (FULL_ACCESS) to grant all permissions.
|
||||
|
|
@ -471,7 +471,7 @@ class BaseModel(Base):
|
|||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
|
||||
class NewChatMessageRole(str, Enum):
|
||||
class NewChatMessageRole(StrEnum):
|
||||
"""Role enum for new chat messages."""
|
||||
|
||||
USER = "user"
|
||||
|
|
@ -479,7 +479,7 @@ class NewChatMessageRole(str, Enum):
|
|||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ChatVisibility(str, Enum):
|
||||
class ChatVisibility(StrEnum):
|
||||
"""
|
||||
Visibility/sharing level for chat threads.
|
||||
|
||||
|
|
@ -788,7 +788,7 @@ class ChatSessionState(BaseModel):
|
|||
ai_responding_to_user = relationship("User")
|
||||
|
||||
|
||||
class MemoryCategory(str, Enum):
|
||||
class MemoryCategory(StrEnum):
|
||||
"""Categories for user memories."""
|
||||
|
||||
# Using lowercase keys to match PostgreSQL enum values
|
||||
|
|
@ -1317,6 +1317,12 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
|||
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
config = Column(JSON, nullable=False)
|
||||
|
||||
# Summary generation (LLM-based) - disabled by default to save resources.
|
||||
# When enabled, improves hybrid search quality at the cost of LLM calls.
|
||||
enable_summary = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
|
||||
# Periodic indexing fields
|
||||
periodic_indexing_enabled = Column(Boolean, nullable=False, default=False)
|
||||
indexing_frequency_minutes = Column(Integer, nullable=True)
|
||||
|
|
|
|||
0
surfsense_backend/app/indexing_pipeline/__init__.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
|
||||
async def index_uploaded_file(
|
||||
markdown_content: str,
|
||||
filename: str,
|
||||
etl_service: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
session: AsyncSession,
|
||||
llm,
|
||||
should_summarize: bool = False,
|
||||
) -> None:
|
||||
connector_doc = ConnectorDocument(
|
||||
title=filename,
|
||||
source_markdown=markdown_content,
|
||||
unique_id=filename,
|
||||
document_type=DocumentType.FILE,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
connector_id=None,
|
||||
should_summarize=should_summarize,
|
||||
should_use_code_chunker=False,
|
||||
fallback_summary=markdown_content[:4000],
|
||||
metadata={
|
||||
"FILE_NAME": filename,
|
||||
"ETL_SERVICE": etl_service,
|
||||
},
|
||||
)
|
||||
|
||||
service = IndexingPipelineService(session)
|
||||
documents = await service.prepare_for_indexing([connector_doc])
|
||||
|
||||
if not documents:
|
||||
raise RuntimeError("prepare_for_indexing returned no documents")
|
||||
|
||||
indexed = await service.index(documents[0], connector_doc, llm)
|
||||
|
||||
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
|
||||
raise RuntimeError(indexed.status.get("reason", "Indexing failed"))
|
||||
|
||||
indexed.content_needs_reindexing = False
|
||||
await session.commit()
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.db import DocumentType
|
||||
|
||||
|
||||
class ConnectorDocument(BaseModel):
|
||||
"""Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""
|
||||
|
||||
title: str
|
||||
source_markdown: str
|
||||
unique_id: str
|
||||
document_type: DocumentType
|
||||
search_space_id: int = Field(gt=0)
|
||||
should_summarize: bool = True
|
||||
should_use_code_chunker: bool = False
|
||||
fallback_summary: str | None = None
|
||||
metadata: dict = {}
|
||||
connector_id: int | None = None
|
||||
created_by_id: str
|
||||
|
||||
@field_validator("title", "source_markdown", "unique_id", "created_by_id")
|
||||
@classmethod
|
||||
def not_empty(cls, v: str, info) -> str:
|
||||
if not v.strip():
|
||||
raise ValueError(f"{info.field_name} must not be empty or whitespace")
|
||||
return v
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
from app.config import config
|
||||
|
||||
|
||||
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
|
||||
"""Chunk a text string using the configured chunker and return the chunk texts."""
|
||||
chunker = (
|
||||
config.code_chunker_instance if use_code_chunker else config.chunker_instance
|
||||
)
|
||||
return [c.text for c in chunker.chunk(text)]
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from app.utils.document_converters import embed_text
|
||||
|
||||
__all__ = ["embed_text"]
|
||||
15
surfsense_backend/app/indexing_pipeline/document_hashing.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
import hashlib
|
||||
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
|
||||
|
||||
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
|
||||
"""Return a stable SHA-256 hash identifying a document by its source identity."""
|
||||
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
|
||||
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def compute_content_hash(doc: ConnectorDocument) -> str:
|
||||
"""Return a SHA-256 hash of the document's content scoped to its search space."""
|
||||
combined = f"{doc.search_space_id}:{doc.source_markdown}"
|
||||
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
||||
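The two hash functions above answer different questions: one identifies where a document comes from, the other what it currently says. The sketch below reproduces that split with the standard library only; `FakeDoc` is a hypothetical stand-in for `ConnectorDocument`, and `document_type` is a plain string here rather than an enum member.

```python
import hashlib
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FakeDoc:  # hypothetical stand-in for ConnectorDocument
    document_type: str
    unique_id: str
    search_space_id: int
    source_markdown: str


def identity_hash(doc: FakeDoc) -> str:
    """Hash of where the document comes from (type, id, search space)."""
    combined = f"{doc.document_type}:{doc.unique_id}:{doc.search_space_id}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()


def content_hash(doc: FakeDoc) -> str:
    """Hash of what the document says, scoped to its search space."""
    combined = f"{doc.search_space_id}:{doc.source_markdown}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()


original = FakeDoc("FILE", "notes.md", 1, "# Notes\nfirst draft")
edited = replace(original, source_markdown="# Notes\nsecond draft")
other_space = replace(original, search_space_id=2)

# Editing the body keeps the identity but changes the content hash.
print(identity_hash(original) == identity_hash(edited))
print(content_hash(original) == content_hash(edited))
```

Because the search space ID is part of both inputs, the same file uploaded to two search spaces produces two distinct identities and two distinct content hashes.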
|
|
@ -0,0 +1,39 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
|
||||
|
||||
async def rollback_and_persist_failure(
|
||||
session: AsyncSession, document: Document, message: str
|
||||
) -> None:
|
||||
"""Roll back the current transaction and best-effort persist a failed status.
|
||||
|
||||
Called exclusively from except blocks — must never raise, or the new exception
|
||||
would chain with the original and mask it entirely.
|
||||
"""
|
||||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
return # Session is completely dead; nothing further we can do.
|
||||
try:
|
||||
await session.refresh(document)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.failed(message)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass # Best-effort; document will be retried on the next sync.
|
||||
|
||||
|
||||
def attach_chunks_to_document(document: Document, chunks: list) -> None:
|
||||
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
if document.id is not None:
|
||||
for chunk in chunks:
|
||||
chunk.document_id = document.id
|
||||
session.add_all(chunks)
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from app.utils.document_converters import optimize_content_for_context_window
|
||||
|
||||
|
||||
async def summarize_document(
|
||||
source_markdown: str, llm, metadata: dict | None = None
|
||||
) -> str:
|
||||
"""Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
|
||||
model_name = getattr(llm, "model", "gpt-3.5-turbo")
|
||||
optimized_content = optimize_content_for_context_window(
|
||||
source_markdown, metadata, model_name
|
||||
)
|
||||
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
|
||||
content_with_metadata = (
|
||||
f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
|
||||
f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
|
||||
)
|
||||
summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
|
||||
summary_content = summary_result.content
|
||||
|
||||
if metadata:
|
||||
metadata_parts = ["# DOCUMENT METADATA"]
|
||||
for key, value in metadata.items():
|
||||
if value:
|
||||
metadata_parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
|
||||
metadata_section = "\n".join(metadata_parts)
|
||||
return f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
|
||||
|
||||
return summary_content
|
||||
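The metadata prefix built at the end of `summarize_document` can be looked at in isolation. The sketch below extracts just that formatting step into a hypothetical helper, `format_metadata_header`; the sample metadata values are made up.

```python
def format_metadata_header(metadata: dict) -> str:
    """Render non-empty metadata entries as a Markdown header block."""
    parts = ["# DOCUMENT METADATA"]
    for key, value in metadata.items():
        if value:  # empty strings, None, and 0 are skipped
            parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
    return "\n".join(parts)


header = format_metadata_header(
    {"FILE_NAME": "report.pdf", "ETL_SERVICE": "DOCLING", "AUTHOR": ""}
)
print(header)
```

Two details are easy to miss: `str.title()` turns `ETL_SERVICE` into `Etl Service` rather than `ETL Service`, and the truthiness check drops legitimate falsy values such as `0` along with empty strings.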
146
surfsense_backend/app/indexing_pipeline/exceptions.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
from litellm.exceptions import (
|
||||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
AuthenticationError,
|
||||
BadGatewayError,
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
Timeout,
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
from sqlalchemy.exc import IntegrityError as IntegrityError
|
||||
|
||||
# Tuples for use directly in except clauses.
|
||||
RETRYABLE_LLM_ERRORS = (
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
ServiceUnavailableError,
|
||||
BadGatewayError,
|
||||
InternalServerError,
|
||||
APIConnectionError,
|
||||
)
|
||||
|
||||
PERMANENT_LLM_ERRORS = (
|
||||
AuthenticationError,
|
||||
PermissionDeniedError,
|
||||
NotFoundError,
|
||||
BadRequestError,
|
||||
UnprocessableEntityError,
|
||||
APIResponseValidationError,
|
||||
)
|
||||
|
||||
# API embedding backends (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings) all normalize their errors to RuntimeError.
|
||||
EMBEDDING_ERRORS = (
|
||||
RuntimeError, # local device failure or API backend normalization
|
||||
OSError, # model files missing or corrupted (local backends)
|
||||
MemoryError, # document too large for available RAM
|
||||
)
|
||||
|
||||
|
||||
class PipelineMessages:
|
||||
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
|
||||
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
|
||||
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
|
||||
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
|
||||
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
|
||||
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
|
||||
|
||||
LLM_AUTH = "LLM authentication failed. Check your API key."
|
||||
LLM_PERMISSION = "LLM request denied. Check your account permissions."
|
||||
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
|
||||
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
|
||||
LLM_UNPROCESSABLE = (
|
||||
"Document exceeds the LLM context window even after optimization."
|
||||
)
|
||||
LLM_RESPONSE = "LLM returned an invalid response."
|
||||
|
||||
EMBEDDING_FAILED = (
|
||||
"Embedding failed. Check your embedding model configuration or service."
|
||||
)
|
||||
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
|
||||
EMBEDDING_MEMORY = "Not enough memory to embed this document."
|
||||
|
||||
CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."
|
||||
|
||||
|
||||
def safe_exception_message(exc: Exception) -> str:
|
||||
"""Return str(exc), falling back to a generic message if str() itself raises."""
|
||||
try:
|
||||
return str(exc)
|
||||
except Exception:
|
||||
return "Something went wrong during indexing. Error details could not be retrieved."
|
||||
|
||||
|
||||
def llm_retryable_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RateLimitError):
|
||||
return PipelineMessages.RATE_LIMIT
|
||||
if isinstance(exc, Timeout):
|
||||
return PipelineMessages.LLM_TIMEOUT
|
||||
if isinstance(exc, ServiceUnavailableError):
|
||||
return PipelineMessages.LLM_UNAVAILABLE
|
||||
if isinstance(exc, BadGatewayError):
|
||||
return PipelineMessages.LLM_BAD_GATEWAY
|
||||
if isinstance(exc, InternalServerError):
|
||||
return PipelineMessages.LLM_SERVER_ERROR
|
||||
if isinstance(exc, APIConnectionError):
|
||||
return PipelineMessages.LLM_CONNECTION
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def llm_permanent_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, AuthenticationError):
|
||||
return PipelineMessages.LLM_AUTH
|
||||
if isinstance(exc, PermissionDeniedError):
|
||||
return PipelineMessages.LLM_PERMISSION
|
||||
if isinstance(exc, NotFoundError):
|
||||
return PipelineMessages.LLM_NOT_FOUND
|
||||
if isinstance(exc, BadRequestError):
|
||||
return PipelineMessages.LLM_BAD_REQUEST
|
||||
if isinstance(exc, UnprocessableEntityError):
|
||||
return PipelineMessages.LLM_UNPROCESSABLE
|
||||
if isinstance(exc, APIResponseValidationError):
|
||||
return PipelineMessages.LLM_RESPONSE
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def embedding_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RuntimeError):
|
||||
return PipelineMessages.EMBEDDING_FAILED
|
||||
if isinstance(exc, OSError):
|
||||
return PipelineMessages.EMBEDDING_MODEL
|
||||
if isinstance(exc, MemoryError):
|
||||
return PipelineMessages.EMBEDDING_MEMORY
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when generating the embedding."
|
||||
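The tuples above are meant to be used directly in `except` clauses, and their order matters because `EMBEDDING_ERRORS` contains `RuntimeError`, of which `RecursionError` is a subclass. A minimal sketch under stated assumptions: `FakeRateLimitError` is a hypothetical stand-in for an LLM client exception, and the tuple contents are simplified.

```python
class FakeRateLimitError(Exception):  # hypothetical stand-in for an LLM error
    pass


RETRYABLE = (FakeRateLimitError, TimeoutError)
EMBEDDING = (RuntimeError, OSError, MemoryError)


def classify(exc: Exception) -> str:
    """Route an exception to a category using except-clause tuples."""
    try:
        raise exc
    except RETRYABLE:
        return "retryable"
    except RecursionError:
        # Must precede EMBEDDING: RecursionError subclasses RuntimeError.
        return "chunking_overflow"
    except EMBEDDING:
        return "embedding"
    except Exception:
        return "unexpected"


print(classify(RecursionError()), classify(RuntimeError("device failure")))
```

If the `RecursionError` clause were moved below the embedding clause, a chunking overflow would be reported as an embedding failure, since Python picks the first matching `except` clause.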
|
|
@ -0,0 +1,237 @@
|
|||
import contextlib
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_embedder import embed_text
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
attach_chunks_to_document,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.document_summarizer import summarize_document
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
EMBEDDING_ERRORS,
|
||||
PERMANENT_LLM_ERRORS,
|
||||
RETRYABLE_LLM_ERRORS,
|
||||
PipelineMessages,
|
||||
embedding_message,
|
||||
llm_permanent_message,
|
||||
llm_retryable_message,
|
||||
safe_exception_message,
|
||||
)
|
||||
from app.indexing_pipeline.pipeline_logger import (
|
||||
PipelineLogContext,
|
||||
log_batch_aborted,
|
||||
log_chunking_overflow,
|
||||
log_doc_skipped_unknown,
|
||||
log_document_queued,
|
||||
log_document_requeued,
|
||||
log_document_updated,
|
||||
log_embedding_error,
|
||||
log_index_started,
|
||||
log_index_success,
|
||||
log_permanent_llm_error,
|
||||
log_race_condition,
|
||||
log_retryable_llm_error,
|
||||
log_unexpected_error,
|
||||
)
|
||||
|
||||
|
||||
class IndexingPipelineService:
|
||||
"""Single pipeline for indexing connector documents. All connectors use this service."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def prepare_for_indexing(
|
||||
self, connector_docs: list[ConnectorDocument]
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Persist new documents and detect changes, returning only those that need indexing.
|
||||
"""
|
||||
documents = []
|
||||
seen_hashes: set[str] = set()
|
||||
batch_ctx = PipelineLogContext(
|
||||
connector_id=connector_docs[0].connector_id if connector_docs else 0,
|
||||
search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
|
||||
unique_id="batch",
|
||||
)
|
||||
|
||||
for connector_doc in connector_docs:
|
||||
ctx = PipelineLogContext(
|
||||
connector_id=connector_doc.connector_id,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
unique_id=connector_doc.unique_id,
|
||||
)
|
||||
try:
|
||||
unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
|
||||
content_hash = compute_content_hash(connector_doc)
|
||||
|
||||
if unique_identifier_hash in seen_hashes:
|
||||
continue
|
||||
seen_hashes.add(unique_identifier_hash)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(Document).filter(
|
||||
Document.unique_identifier_hash == unique_identifier_hash
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing is not None:
|
||||
if existing.content_hash == content_hash:
|
||||
if existing.title != connector_doc.title:
|
||||
existing.title = connector_doc.title
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
if not DocumentStatus.is_state(
|
||||
existing.status, DocumentStatus.READY
|
||||
):
|
||||
existing.status = DocumentStatus.pending()
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
documents.append(existing)
|
||||
log_document_requeued(ctx)
|
||||
continue
|
||||
|
||||
existing.title = connector_doc.title
|
||||
existing.content_hash = content_hash
|
||||
existing.source_markdown = connector_doc.source_markdown
|
||||
existing.document_metadata = connector_doc.metadata
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
existing.status = DocumentStatus.pending()
|
||||
documents.append(existing)
|
||||
log_document_updated(ctx)
|
||||
continue
|
||||
|
||||
duplicate = await self.session.execute(
|
||||
select(Document).filter(Document.content_hash == content_hash)
|
||||
)
|
||||
if duplicate.scalars().first() is not None:
|
||||
continue
|
||||
|
||||
document = Document(
|
||||
title=connector_doc.title,
|
||||
document_type=connector_doc.document_type,
|
||||
content="Pending...",
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
source_markdown=connector_doc.source_markdown,
|
||||
document_metadata=connector_doc.metadata,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
connector_id=connector_doc.connector_id,
|
||||
created_by_id=connector_doc.created_by_id,
|
||||
updated_at=datetime.now(UTC),
|
||||
status=DocumentStatus.pending(),
|
||||
)
|
||||
self.session.add(document)
|
||||
documents.append(document)
|
||||
log_document_queued(ctx)
|
||||
|
||||
except Exception as e:
|
||||
log_doc_skipped_unknown(ctx, e)
|
||||
|
||||
try:
|
||||
await self.session.commit()
|
||||
return documents
|
||||
except IntegrityError:
|
||||
# A concurrent worker committed a document with the same content_hash
|
||||
# or unique_identifier_hash between our check and our INSERT.
|
||||
# The document already exists — roll back and let the next sync run handle it.
|
||||
log_race_condition(batch_ctx)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
except Exception as e:
|
||||
log_batch_aborted(batch_ctx, e)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
|
||||
async def index(
|
||||
self, document: Document, connector_doc: ConnectorDocument, llm
|
||||
) -> Document:
|
||||
"""
|
||||
Run summarization, embedding, and chunking for a document and persist the results.
|
||||
"""
|
||||
ctx = PipelineLogContext(
|
||||
connector_id=connector_doc.connector_id,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
unique_id=connector_doc.unique_id,
|
||||
doc_id=document.id,
|
||||
)
|
||||
try:
|
||||
log_index_started(ctx)
|
||||
document.status = DocumentStatus.processing()
|
||||
await self.session.commit()
|
||||
|
||||
if connector_doc.should_summarize and llm is not None:
|
||||
content = await summarize_document(
|
||||
connector_doc.source_markdown, llm, connector_doc.metadata
|
||||
)
|
||||
elif connector_doc.should_summarize and connector_doc.fallback_summary:
|
||||
content = connector_doc.fallback_summary
|
||||
else:
|
||||
content = connector_doc.source_markdown
|
||||
|
||||
embedding = embed_text(content)
|
||||
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
||||
chunks = [
|
||||
Chunk(content=text, embedding=embed_text(text))
|
||||
for text in chunk_text(
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
]
|
||||
|
||||
document.content = content
|
||||
document.embedding = embedding
|
||||
attach_chunks_to_document(document, chunks)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
|
||||
except RETRYABLE_LLM_ERRORS as e:
|
||||
log_retryable_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_retryable_message(e)
|
||||
)
|
||||
|
||||
except PERMANENT_LLM_ERRORS as e:
|
||||
log_permanent_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_permanent_message(e)
|
||||
)
|
||||
|
||||
except RecursionError as e:
|
||||
log_chunking_overflow(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
|
||||
)
|
||||
|
||||
except EMBEDDING_ERRORS as e:
|
||||
log_embedding_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, embedding_message(e)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log_unexpected_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, safe_exception_message(e)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await self.session.refresh(document)
|
||||
|
||||
return document
|
||||
126
surfsense_backend/app/indexing_pipeline/pipeline_logger.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineLogContext:
|
||||
connector_id: int | None
|
||||
search_space_id: int
|
||||
unique_id: str # always available from ConnectorDocument
|
||||
doc_id: int | None = None # set once the DB row exists (index phase only)
|
||||
|
||||
|
||||
class LogMessages:
|
||||
# prepare_for_indexing
|
||||
DOCUMENT_QUEUED = "New document queued for indexing."
|
||||
DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
|
||||
DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
|
||||
DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
|
||||
BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
|
||||
RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
|
||||
|
||||
# index
|
||||
INDEX_STARTED = "Document indexing started."
|
||||
INDEX_SUCCESS = "Document indexed successfully."
|
||||
LLM_RETRYABLE = (
|
||||
"Retryable LLM error — document marked failed, will retry on next sync."
|
||||
)
|
||||
LLM_PERMANENT = "Permanent LLM error — document marked failed."
|
||||
EMBEDDING_FAILED = "Embedding error — document marked failed."
|
||||
CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
|
||||
UNEXPECTED = "Unexpected error — document marked failed."
|
||||
|
||||
|
||||
def _format_context(ctx: PipelineLogContext) -> str:
|
||||
parts = [
|
||||
f"connector_id={ctx.connector_id}",
|
||||
f"search_space_id={ctx.search_space_id}",
|
||||
f"unique_id={ctx.unique_id}",
|
||||
]
|
||||
if ctx.doc_id is not None:
|
||||
parts.append(f"doc_id={ctx.doc_id}")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
|
||||
try:
|
||||
parts = [msg, _format_context(ctx)]
|
||||
for key, val in extra.items():
|
||||
parts.append(f"{key}={val}")
|
||||
return " ".join(parts)
|
||||
except Exception:
|
||||
return msg
|
||||
|
||||
|
||||
def _safe_log(
|
||||
level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra
|
||||
) -> None:
|
||||
# Logging must never raise — a broken log call inside an except block would
|
||||
# chain with the original exception and mask it entirely.
|
||||
try:
|
||||
message = _build_message(msg, ctx, **extra)
|
||||
level_fn(message, exc_info=exc_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── prepare_for_indexing ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def log_document_queued(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
|
||||
|
||||
|
||||
def log_document_updated(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)
|
||||
|
||||
|
||||
def log_document_requeued(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)
|
||||
|
||||
|
||||
def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(
|
||||
logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc
|
||||
)
|
||||
|
||||
|
||||
def log_race_condition(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)
|
||||
|
||||
|
||||
def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
# ── index ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def log_index_started(ctx: PipelineLogContext) -> None:
|
||||
_safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)
|
||||
|
||||
|
||||
def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
|
||||
_safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)
|
||||
|
||||
|
||||
def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)
|
||||
|
||||
|
||||
def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
|
||||
_safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)
|
||||
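The comment in `_safe_log` explains why logging must never raise. A minimal sketch of the same idea, with simplified hypothetical helpers (`build_message`, `safe_log`) and an `Unprintable` object whose `__str__` fails on purpose:

```python
import logging

logger = logging.getLogger("pipeline_demo")  # hypothetical logger name


class Unprintable:
    def __str__(self) -> str:
        raise ValueError("cannot render")


def build_message(msg: str, **extra) -> str:
    """Append key=value pairs, falling back to the bare message on failure."""
    try:
        return " ".join([msg, *(f"{k}={v}" for k, v in extra.items())])
    except Exception:
        return msg


def safe_log(level_fn, msg: str, **extra) -> None:
    """Log without ever raising, so the caller's original exception survives."""
    try:
        level_fn(build_message(msg, **extra))
    except Exception:
        pass


ok = build_message("Document indexed successfully.", doc_id=7, chunk_count=3)
degraded = build_message("Unexpected error.", error=Unprintable())
safe_log(logger.error, "Unexpected error.", error=Unprintable())  # does not raise
print(ok)
print(degraded)
```

The degraded path loses the extra context but still emits the base message, which is usually preferable to losing the log line or masking the error being reported.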
|
|
@ -36,6 +36,7 @@ from .podcasts_routes import router as podcasts_router
|
|||
from .public_chat_routes import router as public_chat_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
from .reports_routes import router as reports_router
|
||||
from .sandbox_routes import router as sandbox_router
|
||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
from .slack_add_connector_route import router as slack_add_connector_router
|
||||
|
|
@ -50,6 +51,7 @@ router.include_router(editor_router)
|
|||
router.include_router(documents_router)
|
||||
router.include_router(notes_router)
|
||||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
router.include_router(reports_router) # Report CRUD and export (PDF/DOCX)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from app.schemas import (
|
|||
DocumentWithChunksRead,
|
||||
PaginatedResponse,
|
||||
)
|
||||
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -44,6 +45,10 @@ os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_FILES_PER_UPLOAD = 10
|
||||
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB per file
|
||||
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024 # 200 MB total
|
||||
|
||||
|
||||
@router.post("/documents")
|
||||
async def create_documents(
|
||||
|
|
@ -114,8 +119,10 @@ async def create_documents(
|
|||
async def create_documents_file_upload(
|
||||
files: list[UploadFile],
|
||||
search_space_id: int = Form(...),
|
||||
should_summarize: bool = Form(False),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
|
||||
):
|
||||
"""
|
||||
Upload files as documents with real-time status tracking.
|
||||
|
|
@ -148,12 +155,37 @@ async def create_documents_file_upload(
|
|||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files provided")
|
||||
|
||||
if len(files) > MAX_FILES_PER_UPLOAD:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Too many files. Maximum {MAX_FILES_PER_UPLOAD} files per upload.",
|
||||
)
|
||||
|
||||
total_size = 0
|
||||
for file in files:
|
||||
file_size = file.size or 0
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
total_size += file_size
|
||||
|
||||
if total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
created_documents: list[Document] = []
|
||||
files_to_process: list[
|
||||
tuple[Document, str, str]
|
||||
] = [] # (document, temp_path, filename)
|
||||
skipped_duplicates = 0
|
||||
duplicate_document_ids: list[int] = []
|
||||
actual_total_size = 0
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
|
|
@ -169,11 +201,28 @@ async def create_documents_file_upload(
|
|||
temp_path = temp_file.name
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
|
||||
actual_total_size += file_size
|
||||
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
-file_size = len(content)
|
||||
|
||||
# Generate unique identifier for deduplication check
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file.filename or "unknown", search_space_id
|
||||
|
|
@ -244,19 +293,16 @@ async def create_documents_file_upload(
|
|||
for doc in created_documents:
|
||||
await session.refresh(doc)
|
||||
|
||||
# ===== PHASE 2: Dispatch Celery tasks for each file =====
|
||||
# ===== PHASE 2: Dispatch tasks for each file =====
|
||||
# Each task will update document status: pending → processing → ready/failed
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
process_file_upload_with_document_task,
|
||||
)
|
||||
|
||||
for document, temp_path, filename in files_to_process:
|
||||
process_file_upload_with_document_task.delay(
|
||||
await dispatcher.dispatch_file_processing(
|
||||
document_id=document.id,
|
||||
temp_path=temp_path,
|
||||
filename=filename,
|
||||
search_space_id=search_space_id,
|
||||
user_id=str(user.id),
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -373,10 +419,11 @@ async def read_documents(
|
|||
# Convert database objects to API-friendly format
|
||||
api_documents = []
|
||||
for doc in db_documents:
|
||||
# Get user name (display_name or email fallback)
|
||||
created_by_name = None
|
||||
created_by_email = None
|
||||
if doc.created_by:
|
||||
created_by_name = doc.created_by.display_name or doc.created_by.email
|
||||
created_by_name = doc.created_by.display_name
|
||||
created_by_email = doc.created_by.email
|
||||
|
||||
# Parse status from JSONB
|
||||
status_data = None
|
||||
|
|
@ -400,6 +447,7 @@ async def read_documents(
|
|||
search_space_id=doc.search_space_id,
|
||||
created_by_id=doc.created_by_id,
|
||||
created_by_name=created_by_name,
|
||||
created_by_email=created_by_email,
|
||||
status=status_data,
|
||||
)
|
||||
)
|
||||
|
|
@ -528,10 +576,11 @@ async def search_documents(
|
|||
# Convert database objects to API-friendly format
|
||||
api_documents = []
|
||||
for doc in db_documents:
|
||||
# Get user name (display_name or email fallback)
|
||||
created_by_name = None
|
||||
created_by_email = None
|
||||
if doc.created_by:
|
||||
created_by_name = doc.created_by.display_name or doc.created_by.email
|
||||
created_by_name = doc.created_by.display_name
|
||||
created_by_email = doc.created_by.email
|
||||
|
||||
# Parse status from JSONB
|
||||
status_data = None
|
||||
|
|
@ -555,6 +604,7 @@ async def search_documents(
|
|||
search_space_id=doc.search_space_id,
|
||||
created_by_id=doc.created_by_id,
|
||||
created_by_name=created_by_name,
|
||||
created_by_email=created_by_email,
|
||||
status=status_data,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -76,9 +76,9 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
# Google Drive OAuth scopes
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly", # Read-only access to Drive
|
||||
"https://www.googleapis.com/auth/userinfo.email", # User email
|
||||
"https://www.googleapis.com/auth/userinfo.profile", # User profile
|
||||
"https://www.googleapis.com/auth/drive",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"openid",
|
||||
]
|
||||
|
||||
|
|
@ -151,6 +151,75 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/drive/connector/reauth")
|
||||
async def reauth_drive(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Initiate Google Drive re-authentication to upgrade OAuth scopes.
|
||||
|
||||
Query params:
|
||||
space_id: Search space ID the connector belongs to
|
||||
connector_id: ID of the existing connector to re-authenticate
|
||||
|
||||
Returns:
|
||||
JSON with auth_url to redirect user to Google authorization
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Google Drive connector not found or access denied",
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initiating Google Drive re-auth for user {user.id}, connector {connector_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Google Drive re-auth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Google re-auth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/drive/connector/callback")
|
||||
async def drive_callback(
|
||||
request: Request,
|
||||
|
|
@ -214,6 +283,8 @@ async def drive_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
logger.info(
|
||||
f"Processing Google Drive callback for user {user_id}, space {space_id}"
|
||||
|
|
@ -253,7 +324,45 @@ async def drive_callback(
|
|||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
# Check for duplicate connector (same account already connected)
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
existing_start_page_token = db_connector.config.get("start_page_token")
|
||||
db_connector.config = {
|
||||
**creds_dict,
|
||||
"start_page_token": existing_start_page_token,
|
||||
}
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Re-authenticated Google Drive connector {db_connector.id} for user {user_id}"
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=google-drive-connector&connectorId={db_connector.id}"
|
||||
)
|
||||
|
||||
is_duplicate = await check_duplicate_connector(
|
||||
session,
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
|||
- POST /threads/{thread_id}/messages - Append message
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
|
@ -52,9 +54,50 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
|||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _try_delete_sandbox(thread_id: int) -> None:
|
||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
delete_local_sandbox_files,
|
||||
delete_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
return
|
||||
|
||||
async def _bg() -> None:
|
||||
try:
|
||||
await delete_sandbox(thread_id)
|
||||
except Exception:
|
||||
_logger.warning(
|
||||
"Background sandbox delete failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
delete_local_sandbox_files(thread_id)
|
||||
except Exception:
|
||||
_logger.warning(
|
||||
"Local sandbox file cleanup failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(_bg())
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def check_thread_access(
|
||||
session: AsyncSession,
|
||||
thread: NewChatThread,
|
||||
|
|
@ -648,6 +691,9 @@ async def delete_thread(
|
|||
|
||||
await session.delete(db_thread)
|
||||
await session.commit()
|
||||
|
||||
_try_delete_sandbox(thread_id)
|
||||
|
||||
return {"message": "Thread deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
import pypandoc
|
||||
import typst
|
||||
|
|
@ -46,7 +46,7 @@ router = APIRouter()
|
|||
MAX_REPORT_LIST_LIMIT = 500
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
class ExportFormat(StrEnum):
|
||||
PDF = "pdf"
|
||||
DOCX = "docx"
|
||||
|
||||
|
|
|
|||
105
surfsense_backend/app/routes/sandbox_routes.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Routes for downloading files from Daytona sandbox environments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import NewChatThread, Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MIME_TYPES: dict[str, str] = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
".svg": "image/svg+xml",
|
||||
".pdf": "application/pdf",
|
||||
".csv": "text/csv",
|
||||
".json": "application/json",
|
||||
".txt": "text/plain",
|
||||
".html": "text/html",
|
||||
".md": "text/markdown",
|
||||
".py": "text/x-python",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".zip": "application/zip",
|
||||
}
|
||||
|
||||
|
||||
def _guess_media_type(filename: str) -> str:
|
||||
ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else ""
|
||||
return MIME_TYPES.get(ext, "application/octet-stream")
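To make the lookup behaviour of `_guess_media_type` concrete, the sketch below reproduces its logic with a shortened table (the three entries are copied from `MIME_TYPES` above; the sample filenames are made up). Only the text after the last dot is considered, and matching is case-insensitive.

```python
# Shortened copy of the MIME table, for illustration only.
MIME_TYPES: dict[str, str] = {
    ".png": "image/png",
    ".pdf": "application/pdf",
    ".csv": "text/csv",
}


def guess_media_type(filename: str) -> str:
    # Take the text after the last dot, lower-cased; no dot means no extension.
    ext = ("." + filename.rsplit(".", 1)[-1].lower()) if "." in filename else ""
    return MIME_TYPES.get(ext, "application/octet-stream")


print(guess_media_type("Report.PDF"))      # application/pdf
print(guess_media_type("plot.final.png"))  # image/png
print(guess_media_type("archive.tar.gz"))  # application/octet-stream
print(guess_media_type("Makefile"))        # application/octet-stream
```

Unknown or missing extensions fall back to `application/octet-stream`, which browsers treat as a generic binary download.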
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/sandbox/download")
|
||||
async def download_sandbox_file(
|
||||
thread_id: int,
|
||||
path: str = Query(..., description="Absolute path of the file inside the sandbox"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Download a file from the Daytona sandbox associated with a chat thread."""
|
||||
|
||||
from app.agents.new_chat.sandbox import get_or_create_sandbox, is_sandbox_enabled
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
raise HTTPException(status_code=404, detail="Sandbox is not enabled")
|
||||
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to access files in this thread",
|
||||
)
|
||||
|
||||
from app.agents.new_chat.sandbox import get_local_sandbox_file
|
||||
|
||||
# Prefer locally-persisted copy (sandbox may already be deleted)
|
||||
local_content = get_local_sandbox_file(thread_id, path)
|
||||
if local_content is not None:
|
||||
filename = path.rsplit("/", 1)[-1] if "/" in path else path
|
||||
media_type = _guess_media_type(filename)
|
||||
return Response(
|
||||
content=local_content,
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
# Fall back to live sandbox download
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(thread_id)
|
||||
raw_sandbox = sandbox._sandbox
|
||||
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
|
||||
except Exception as exc:
|
||||
logger.warning("Sandbox file download failed for %s: %s", path, exc)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Could not download file: {exc}"
|
||||
) from exc
|
||||
|
||||
filename = path.rsplit("/", 1)[-1] if "/" in path else path
|
||||
media_type = _guess_media_type(filename)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
|
@ -2735,7 +2735,10 @@ async def create_mcp_connector(
|
|||
f"for user {user.id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Convert to read schema
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(search_space_id)
|
||||
|
||||
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
||||
return MCPConnectorRead.from_connector(connector_read)
|
||||
|
||||
|
|
@ -2910,6 +2913,10 @@ async def update_mcp_connector(
|
|||
|
||||
logger.info(f"Updated MCP connector {connector_id}")
|
||||
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||
return MCPConnectorRead.from_connector(connector_read)
|
||||
|
||||
|
|
@ -2960,9 +2967,14 @@ async def delete_mcp_connector(
|
|||
"You don't have permission to delete this connector",
|
||||
)
|
||||
|
||||
search_space_id = connector.search_space_id
|
||||
await session.delete(connector)
|
||||
await session.commit()
|
||||
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(search_space_id)
|
||||
|
||||
logger.info(f"Deleted MCP connector {connector_id}")
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -60,9 +60,8 @@ class DocumentRead(BaseModel):
|
|||
updated_at: datetime | None
|
||||
search_space_id: int
|
||||
created_by_id: UUID | None = None # User who created/uploaded this document
|
||||
created_by_name: str | None = (
|
||||
None # Display name or email of the user who created this document
|
||||
)
|
||||
created_by_name: str | None = None
|
||||
created_by_email: str | None = None
|
||||
status: DocumentStatusSchema | None = (
|
||||
None # Processing status (ready, processing, failed)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Podcast schemas for API responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PodcastStatusEnum(str, Enum):
|
||||
class PodcastStatusEnum(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class SearchSourceConnectorBase(BaseModel):
|
|||
is_indexable: bool
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any]
|
||||
enable_summary: bool = False
|
||||
periodic_indexing_enabled: bool = False
|
||||
indexing_frequency_minutes: int | None = None
|
||||
next_scheduled_at: datetime | None = None
|
||||
|
|
@ -65,6 +66,7 @@ class SearchSourceConnectorUpdate(BaseModel):
|
|||
is_indexable: bool | None = None
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
enable_summary: bool | None = None
|
||||
periodic_indexing_enabled: bool | None = None
|
||||
indexing_frequency_minutes: int | None = None
|
||||
next_scheduled_at: datetime | None = None
|
||||
|
|
|
|||
|
|
@ -1303,10 +1303,9 @@ class ConnectorService:
|
|||
|
||||
sources_list = self._build_chunk_sources_from_documents(
|
||||
github_docs,
|
||||
description_fn=lambda chunk, _doc_info, metadata: metadata.get(
|
||||
"description"
|
||||
)
|
||||
or chunk.get("content", ""),
|
||||
description_fn=lambda chunk, _doc_info, metadata: (
|
||||
metadata.get("description") or chunk.get("content", "")
|
||||
),
|
||||
url_fn=lambda _doc_info, metadata: metadata.get("url", "") or "",
|
||||
)
|
||||
|
||||
|
|
|
|||
11
surfsense_backend/app/services/google_drive/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from app.services.google_drive.tool_metadata_service import (
|
||||
GoogleDriveAccount,
|
||||
GoogleDriveFile,
|
||||
GoogleDriveToolMetadataService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GoogleDriveAccount",
|
||||
"GoogleDriveFile",
|
||||
"GoogleDriveToolMetadataService",
|
||||
]
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleDriveAccount:
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: SearchSourceConnector) -> "GoogleDriveAccount":
|
||||
return cls(id=connector.id, name=connector.name)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"id": self.id, "name": self.name}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleDriveFile:
|
||||
file_id: str
|
||||
name: str
|
||||
mime_type: str
|
||||
web_view_link: str
|
||||
connector_id: int
|
||||
document_id: int
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "GoogleDriveFile":
|
||||
meta = document.document_metadata or {}
|
||||
return cls(
|
||||
file_id=meta.get("google_drive_file_id", ""),
|
||||
name=meta.get("google_drive_file_name", document.title),
|
||||
mime_type=meta.get("google_drive_mime_type", ""),
|
||||
web_view_link=meta.get("web_view_link", ""),
|
||||
connector_id=document.connector_id,
|
||||
document_id=document.id,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"file_id": self.file_id,
|
||||
"name": self.name,
|
||||
"mime_type": self.mime_type,
|
||||
"web_view_link": self.web_view_link,
|
||||
"connector_id": self.connector_id,
|
||||
"document_id": self.document_id,
|
||||
}
|
||||
|
||||
|
||||
class GoogleDriveToolMetadataService:
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
||||
|
||||
if not accounts:
|
||||
return {
|
||||
"accounts": [],
|
||||
"supported_types": [],
|
||||
"error": "No Google Drive account connected",
|
||||
}
|
||||
|
||||
return {
|
||||
"accounts": [acc.to_dict() for acc in accounts],
|
||||
"supported_types": ["google_doc", "google_sheet"],
|
||||
}
|
||||
|
||||
async def get_trash_context(
|
||||
self, search_space_id: int, user_id: str, file_name: str
|
||||
) -> dict:
|
||||
result = await self._db_session.execute(
|
||||
select(Document)
|
||||
.join(
|
||||
SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_type == DocumentType.GOOGLE_DRIVE_FILE,
|
||||
func.lower(Document.title) == func.lower(file_name),
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
if not document:
|
||||
return {
|
||||
"error": (
|
||||
f"File '{file_name}' not found in your indexed Google Drive files. "
|
||||
"This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the file name is different."
|
||||
)
|
||||
}
|
||||
|
||||
if not document.connector_id:
|
||||
return {"error": "Document has no associated connector"}
|
||||
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
and_(
|
||||
SearchSourceConnector.id == document.connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
return {"error": "Connector not found or access denied"}
|
||||
|
||||
account = GoogleDriveAccount.from_connector(connector)
|
||||
file = GoogleDriveFile.from_document(document)
|
||||
|
||||
return {
|
||||
"account": account.to_dict(),
|
||||
"file": file.to_dict(),
|
||||
}
|
||||
|
||||
async def _get_google_drive_accounts(
|
||||
self, search_space_id: int, user_id: str
|
||||
) -> list[GoogleDriveAccount]:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
|
||||
)
|
||||
)
|
||||
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
return [GoogleDriveAccount.from_connector(c) for c in connectors]
|
||||
|
|
@ -4,12 +4,12 @@ from datetime import datetime
|
|||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Chunk, Document
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
)
|
||||
|
|
@ -80,7 +80,7 @@ class LinearKBSyncService:
|
|||
state = formatted_issue.get("state", "Unknown")
|
||||
priority = issue_raw.get("priorityLabel", "Unknown")
|
||||
comment_count = len(formatted_issue.get("comments", []))
|
||||
description = formatted_issue.get("description", "")
|
||||
formatted_issue.get("description", "")
|
||||
|
||||
user_llm = await get_user_long_context_llm(
|
||||
self.db_session, user_id, search_space_id, disable_streaming=True
|
||||
|
|
@ -100,18 +100,10 @@ class LinearKBSyncService:
|
|||
issue_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
if description and len(description) > 1000:
|
||||
description = description[:997] + "..."
|
||||
summary_content = (
|
||||
f"Linear Issue {issue_identifier}: {issue_title}\n\n"
|
||||
f"Status: {state}\n\n"
|
||||
)
|
||||
if description:
|
||||
summary_content += f"Description: {description}\n\n"
|
||||
summary_content += f"Comments: {comment_count}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
f"Linear Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
await self.db_session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
|
|
|
|||
|
|
@ -12,16 +12,35 @@ synchronous ChatLiteLLM-like interface and async methods.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.exceptions import ContextOverflowError
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from litellm import Router
|
||||
from litellm.exceptions import (
|
||||
BadRequestError as LiteLLMBadRequestError,
|
||||
ContextWindowExceededError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
|
||||
r"(input tokens exceed|context.{0,20}(length|window|limit)|"
|
||||
r"maximum context length|token.{0,20}(limit|exceed)|"
|
||||
r"too many tokens|reduce the length)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
|
||||
"""Check if a BadRequestError is actually a context window overflow."""
|
||||
return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
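The message-matching heuristic above can be exercised without LiteLLM. In the sketch below the regular expression is copied from the hunk, while the sample error messages are invented for illustration and plain `Exception` objects stand in for `BadRequestError`.

```python
import re

# Pattern copied from the hunk above.
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
    r"maximum context length|token.{0,20}(limit|exceed)|"
    r"too many tokens|reduce the length)",
    re.IGNORECASE,
)


def is_context_overflow(exc: Exception) -> bool:
    """Return True if the error text looks like a context window overflow."""
    return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))


# Invented sample messages, not real provider output.
overflow = Exception("This model's maximum context length is 8192 tokens.")
auth_error = Exception("Invalid API key provided")
rate_limit = Exception("Token rate limit exceeded")

print(is_context_overflow(overflow))    # True
print(is_context_overflow(auth_error))  # False
print(is_context_overflow(rate_limit))  # True: matches "token ... limit"
```

The last case shows the trade-off of a text heuristic: a message that merely mentions a token limit is also classified as an overflow, so callers that translate matches into `ContextOverflowError` may occasionally mislabel an unrelated bad request.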
|
||||
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
|
|
@ -234,6 +253,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
|
||||
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
|
||||
making it a drop-in replacement for auto-mode routing.
|
||||
|
||||
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
|
||||
window across all router deployments so that deepagents
|
||||
SummarizationMiddleware can use fraction-based triggers.
|
||||
"""
|
||||
|
||||
# Use model_config for Pydantic v2 compatibility
|
||||
|
|
@ -265,7 +288,6 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
"""
|
||||
try:
|
||||
super().__init__(**kwargs)
|
||||
# Store router and tools as private attributes
|
||||
resolved_router = router or LLMRouterService.get_router()
|
||||
object.__setattr__(self, "_router", resolved_router)
|
||||
object.__setattr__(self, "_bound_tools", bound_tools)
|
||||
|
|
@ -274,6 +296,12 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
raise ValueError(
|
||||
"LLM Router not initialized. Call LLMRouterService.initialize() first."
|
||||
)
|
||||
|
||||
# Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
|
||||
computed_profile = self._compute_min_context_profile()
|
||||
if computed_profile is not None:
|
||||
object.__setattr__(self, "profile", computed_profile)
|
||||
|
||||
logger.info(
|
||||
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
|
||||
)
|
||||
|
|
@ -281,6 +309,39 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
||||
raise
|
||||
|
||||
def _compute_min_context_profile(self) -> dict | None:
|
||||
"""Derive a profile dict with max_input_tokens from router deployments.
|
||||
|
||||
Uses litellm.get_model_info to look up each deployment's context window
|
||||
and picks the *minimum* so that summarization triggers before ANY model
|
||||
in the pool overflows.
|
||||
"""
|
||||
from litellm import get_model_info
|
||||
|
||||
if not self._router:
|
||||
return None
|
||||
|
||||
min_ctx: int | None = None
|
||||
for deployment in self._router.model_list:
|
||||
params = deployment.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
try:
|
||||
info = get_model_info(base_model)
|
||||
ctx = info.get("max_input_tokens")
|
||||
if (
|
||||
isinstance(ctx, int)
|
||||
and ctx > 0
|
||||
and (min_ctx is None or ctx < min_ctx)
|
||||
):
|
||||
min_ctx = ctx
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if min_ctx is not None:
|
||||
logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
|
||||
return {"max_input_tokens": min_ctx}
|
||||
return None
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "litellm-router"
|
||||
|
|
@ -359,13 +420,19 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router completion
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
try:
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
|
|
@ -396,13 +463,19 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router async completion
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
try:
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
|
|
@ -432,14 +505,20 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router completion with streaming
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
try:
|
||||
response = self._router.completion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
# Yield chunks
|
||||
for chunk in response:
|
||||
|
|
@ -471,14 +550,20 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if self._tool_choice is not None:
|
||||
call_kwargs["tool_choice"] = self._tool_choice
|
||||
|
||||
# Call router async completion with streaming
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
try:
|
||||
response = await self._router.acompletion(
|
||||
model=self.model,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
# Yield chunks asynchronously
|
||||
async for chunk in response:
|
||||
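The four hunks above wrap each router call in the same exception-translation pattern: a typed overflow error is always converted, while a generic bad-request error is converted only when it looks like a context overflow. A minimal sketch of that pattern follows. The exception classes are local stand-ins rather than the real LiteLLM types, the subclass relationship between them is an assumption, and the body of `_is_context_overflow_error` is hypothetical because the diff does not show it.

```python
class BadRequestError(Exception):
    """Stand-in for LiteLLMBadRequestError (not the real class)."""


class ContextWindowExceededError(BadRequestError):
    """Stand-in for the typed overflow error; assumed to subclass the above."""


class ContextOverflowError(Exception):
    """Application-level error that callers can handle uniformly."""


# Hypothetical message fragments; the real helper may check other things.
_OVERFLOW_HINTS = ("context length", "context window", "maximum context")


def _is_context_overflow_error(exc: Exception) -> bool:
    message = str(exc).lower()
    return any(hint in message for hint in _OVERFLOW_HINTS)


def call_with_overflow_translation(fn, *args, **kwargs):
    """Call fn and translate provider overflow errors into ContextOverflowError."""
    try:
        return fn(*args, **kwargs)
    except ContextWindowExceededError as e:
        # The more specific class must be caught first.
        raise ContextOverflowError(str(e)) from e
    except BadRequestError as e:
        if _is_context_overflow_error(e):
            raise ContextOverflowError(str(e)) from e
        raise  # unrelated bad requests propagate unchanged
```

Using `raise ... from e` keeps the provider error available as `__cause__`, so logs still show the original failure.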
|
|
|
|||
|
|
@ -4,11 +4,11 @@ from datetime import datetime
|
|||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
)
|
||||
|
|
@ -127,10 +127,8 @@ class NotionKBSyncService:
|
|||
logger.debug(f"Generated summary length: {len(summary_content)} chars")
|
||||
else:
|
||||
logger.warning("No LLM configured - using fallback summary")
|
||||
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content[:500]}..."
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Notion Page: {document.document_metadata.get('page_title')}\n\n{full_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
logger.debug(f"Deleting old chunks for document {document_id}")
|
||||
await self.db_session.execute(
|
||||
|
|
|
|||
53
surfsense_backend/app/services/task_dispatcher.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Task dispatcher abstraction for background document processing.
|
||||
|
||||
Decouples the upload endpoint from Celery so tests can swap in a
|
||||
synchronous (inline) implementation that needs only PostgreSQL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TaskDispatcher(Protocol):
|
||||
async def dispatch_file_processing(
|
||||
self,
|
||||
*,
|
||||
document_id: int,
|
||||
temp_path: str,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class CeleryTaskDispatcher:
|
||||
"""Production dispatcher — fires Celery tasks via Redis broker."""
|
||||
|
||||
async def dispatch_file_processing(
|
||||
self,
|
||||
*,
|
||||
document_id: int,
|
||||
temp_path: str,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
) -> None:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
process_file_upload_with_document_task,
|
||||
)
|
||||
|
||||
process_file_upload_with_document_task.delay(
|
||||
document_id=document_id,
|
||||
temp_path=temp_path,
|
||||
filename=filename,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
|
||||
async def get_task_dispatcher() -> TaskDispatcher:
|
||||
return CeleryTaskDispatcher()
|
||||
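The module docstring above says tests can swap in a synchronous (inline) implementation. One possible shape for such a test double is sketched below. `InlineTaskDispatcher` and its `processor` argument are hypothetical names that do not appear in the diff; the keyword-only signature mirrors the `TaskDispatcher` protocol shown above.

```python
from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Any


class InlineTaskDispatcher:
    """Hypothetical test dispatcher: awaits the processor directly, no broker."""

    def __init__(self, processor: Callable[..., Awaitable[Any]]) -> None:
        # processor is any async callable that accepts the same keyword arguments.
        self._processor = processor

    async def dispatch_file_processing(
        self,
        *,
        document_id: int,
        temp_path: str,
        filename: str,
        search_space_id: int,
        user_id: str,
        should_summarize: bool = False,
    ) -> None:
        # Runs in the caller's event loop, so the work is finished on return.
        await self._processor(
            document_id=document_id,
            temp_path=temp_path,
            filename=filename,
            search_space_id=search_space_id,
            user_id=user_id,
            should_summarize=should_summarize,
        )
```

Because `TaskDispatcher` is a `Protocol`, a class like this satisfies it structurally without inheriting from it.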
|
|
@ -626,6 +626,7 @@ def process_file_upload_with_document_task(
|
|||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
):
|
||||
"""
|
||||
Celery task to process uploaded file with existing pending document.
|
||||
|
|
@ -640,6 +641,7 @@ def process_file_upload_with_document_task(
|
|||
filename: Original filename
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
should_summarize: Whether to generate an LLM summary
|
||||
"""
|
||||
import traceback
|
||||
|
||||
|
|
@ -674,7 +676,12 @@ def process_file_upload_with_document_task(
|
|||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_with_document(
|
||||
document_id, temp_path, filename, search_space_id, user_id
|
||||
document_id,
|
||||
temp_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
user_id,
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
|
|
@ -710,6 +717,7 @@ async def _process_file_with_document(
|
|||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
):
|
||||
"""
|
||||
Process file and update existing pending document status.
|
||||
|
|
@ -811,6 +819,7 @@ async def _process_file_with_document(
|
|||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
notification=notification,
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
# Update notification on success
|
||||
|
|
|
|||
|
|
@ -9,17 +9,21 @@ Supports loading LLM configurations from:
|
|||
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
|
|
@ -30,7 +34,20 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument, async_session_maker
|
||||
from app.agents.new_chat.sandbox import (
|
||||
get_or_create_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
Document,
|
||||
NewChatMessage,
|
||||
NewChatThread,
|
||||
Report,
|
||||
SearchSourceConnectorType,
|
||||
SurfsenseDocsDocument,
|
||||
async_session_maker,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
|
|
@ -40,6 +57,16 @@ from app.services.connector_service import ConnectorService
|
|||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
_perf_log.setLevel(logging.DEBUG)
|
||||
if not _perf_log.handlers:
|
||||
_h = logging.StreamHandler()
|
||||
_h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
|
||||
_perf_log.addHandler(_h)
|
||||
_perf_log.propagate = False
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
|
||||
"""
|
||||
|
|
@ -187,6 +214,7 @@ class StreamResult:
|
|||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -404,6 +432,21 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
cmd = (
|
||||
tool_input.get("command", "")
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input)
|
||||
)
|
||||
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
|
||||
last_active_step_title = "Running command"
|
||||
last_active_step_items = [f"$ {display_cmd}"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Running command",
|
||||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
else:
|
||||
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
|
||||
last_active_step_items = []
|
||||
|
|
@ -620,6 +663,32 @@ async def _stream_agent_events(
|
|||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
else str(tool_output)
|
||||
)
|
||||
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
|
||||
exit_code_val = int(m.group(1)) if m else None
|
||||
if exit_code_val is not None and exit_code_val == 0:
|
||||
completed_items = [
|
||||
*last_active_step_items,
|
||||
"Completed successfully",
|
||||
]
|
||||
elif exit_code_val is not None:
|
||||
completed_items = [
|
||||
*last_active_step_items,
|
||||
f"Exit code: {exit_code_val}",
|
||||
]
|
||||
else:
|
||||
completed_items = [*last_active_step_items, "Finished"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Running command",
|
||||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "ls":
|
||||
if isinstance(tool_output, dict):
|
||||
ls_output = tool_output.get("result", "")
|
||||
|
|
@ -804,6 +873,8 @@ async def _stream_agent_events(
|
|||
"create_linear_issue",
|
||||
"update_linear_issue",
|
||||
"delete_linear_issue",
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
):
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
|
|
@ -811,6 +882,36 @@ async def _stream_agent_events(
|
|||
if isinstance(tool_output, dict)
|
||||
else {"result": tool_output},
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
else str(tool_output)
|
||||
)
|
||||
exit_code: int | None = None
|
||||
output_text = raw_text
|
||||
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
|
||||
if m:
|
||||
exit_code = int(m.group(1))
|
||||
om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
|
||||
output_text = om.group(1) if om else ""
|
||||
thread_id_str = config.get("configurable", {}).get("thread_id", "")
|
||||
|
||||
for sf_match in re.finditer(
|
||||
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
|
||||
):
|
||||
fpath = sf_match.group(1).strip()
|
||||
if fpath and fpath not in result.sandbox_files:
|
||||
result.sandbox_files.append(fpath)
|
||||
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{
|
||||
"exit_code": exit_code,
|
||||
"output": output_text,
|
||||
"thread_id": thread_id_str,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
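The `execute` branch above parses the tool's raw text into an exit code, the output body, and any `SANDBOX_FILE:` markers. The same parsing is sketched below as a standalone function, using the regular expressions from the hunk. The function name and return shape are assumptions made for illustration.

```python
from __future__ import annotations

import re


def parse_execute_output(raw_text: str) -> tuple[int | None, str, list[str]]:
    """Split raw execute output into (exit_code, output_text, sandbox_files)."""
    exit_code: int | None = None
    output_text = raw_text

    # Only text that starts with "Exit code: N" is treated as structured.
    m = re.match(r"^Exit code:\s*(\d+)", raw_text)
    if m:
        exit_code = int(m.group(1))
        om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
        output_text = om.group(1) if om else ""

    # Collect file markers in first-seen order, skipping duplicates.
    sandbox_files: list[str] = []
    for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE):
        fpath = sf_match.group(1).strip()
        if fpath and fpath not in sandbox_files:
            sandbox_files.append(fpath)

    return exit_code, output_text, sandbox_files
```

Text that does not begin with an exit-code line is passed through unchanged with `exit_code` left as `None`, which matches the "Finished" fallback in the thinking-step branch earlier in the diff.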
|
|
@ -879,6 +980,38 @@ async def _stream_agent_events(
|
|||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
||||
|
||||
|
||||
def _try_persist_and_delete_sandbox(
|
||||
thread_id: int,
|
||||
sandbox_files: list[str],
|
||||
) -> None:
|
||||
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
|
||||
from app.agents.new_chat.sandbox import (
|
||||
is_sandbox_enabled,
|
||||
persist_and_delete_sandbox,
|
||||
)
|
||||
|
||||
if not is_sandbox_enabled():
|
||||
return
|
||||
|
||||
async def _run() -> None:
|
||||
try:
|
||||
await persist_and_delete_sandbox(thread_id, sandbox_files)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"persist_and_delete_sandbox failed for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(_run())
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def stream_new_chat(
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
|
|
@ -915,6 +1048,8 @@ async def stream_new_chat(
|
|||
str: SSE formatted response strings
|
||||
"""
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
|
|
@ -923,6 +1058,7 @@ async def stream_new_chat(
|
|||
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
|
||||
agent_config: AgentConfig | None = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
if llm_config_id >= 0:
|
||||
# Positive ID: Load from NewLLMConfig database table
|
||||
agent_config = await load_agent_config(
|
||||
|
|
@ -953,6 +1089,11 @@ async def stream_new_chat(
|
|||
llm = create_chat_litellm_from_config(llm_config)
|
||||
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
llm_config_id,
|
||||
)
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -960,22 +1101,45 @@ async def stream_new_chat(
|
|||
return
|
||||
|
||||
# Create connector service
|
||||
_t0 = time.perf_counter()
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
|
||||
# Get Firecrawl API key from webcrawler connector if configured
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
firecrawl_api_key = None
|
||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||
)
|
||||
if webcrawler_connector and webcrawler_connector.config:
|
||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||
_t0 = time.perf_counter()
|
||||
checkpointer = await get_checkpointer()
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
sandbox_backend = None
|
||||
_t0 = time.perf_counter()
|
||||
if is_sandbox_enabled():
|
||||
try:
|
||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
||||
except Exception as sandbox_err:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Sandbox creation failed, continuing without execute tool: %s",
|
||||
sandbox_err,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
sandbox_backend is not None,
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -987,20 +1151,22 @@ async def stream_new_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
sandbox_backend=sandbox_backend,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Build input with message history
|
||||
langchain_messages = []
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
||||
if needs_history_bootstrap:
|
||||
langchain_messages = await bootstrap_history_from_db(
|
||||
session, chat_id, thread_visibility=visibility
|
||||
)
|
||||
|
||||
# Clear the flag so we don't bootstrap again on next message
|
||||
from app.db import NewChatThread
|
||||
|
||||
thread_result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
|
|
@ -1012,11 +1178,9 @@ async def stream_new_chat(
|
|||
# Fetch mentioned documents if any (with chunks for proper citations)
|
||||
mentioned_documents: list[Document] = []
|
||||
if mentioned_document_ids:
|
||||
from sqlalchemy.orm import selectinload as doc_selectinload
|
||||
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.options(doc_selectinload(Document.chunks))
|
||||
.options(selectinload(Document.chunks))
|
||||
.filter(
|
||||
Document.id.in_(mentioned_document_ids),
|
||||
Document.search_space_id == search_space_id,
|
||||
|
|
@ -1027,8 +1191,6 @@ async def stream_new_chat(
|
|||
# Fetch mentioned SurfSense docs if any
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
||||
if mentioned_surfsense_doc_ids:
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await session.execute(
|
||||
select(SurfsenseDocsDocument)
|
||||
.options(selectinload(SurfsenseDocsDocument.chunks))
|
||||
|
|
@ -1112,6 +1274,11 @@ async def stream_new_chat(
|
|||
"search_space_id": search_space_id,
|
||||
}
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# All pre-streaming DB reads are done. Commit to release the
|
||||
# transaction and its ACCESS SHARE locks so we don't block DDL
|
||||
# (e.g. migrations) for the entire duration of LLM streaming.
|
||||
|
|
@ -1119,6 +1286,12 @@ async def stream_new_chat(
|
|||
# short-lived transactions (or use isolated sessions).
|
||||
await session.commit()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
# Configure LangGraph with thread_id for memory
|
||||
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
||||
configurable = {"thread_id": str(chat_id)}
|
||||
|
|
@ -1180,7 +1353,8 @@ async def stream_new_chat(
|
|||
items=initial_items,
|
||||
)
|
||||
|
||||
stream_result = StreamResult()
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
|
|
@ -1192,8 +1366,23 @@ async def stream_new_chat(
|
|||
initial_step_title=initial_title,
|
||||
initial_step_items=initial_items,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
if stream_result.is_interrupted:
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
@ -1202,12 +1391,6 @@ async def stream_new_chat(
|
|||
|
||||
accumulated_text = stream_result.accumulated_text
|
||||
|
||||
# Generate LLM title for new chats after first response
|
||||
# Check if this is the first assistant response by counting existing assistant messages
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
|
||||
assistant_count_result = await session.execute(
|
||||
select(func.count(NewChatMessage.id)).filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
|
|
@ -1294,6 +1477,8 @@ async def stream_new_chat(
|
|||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
|
||||
async def stream_resume_chat(
|
||||
chat_id: int,
|
||||
|
|
@ -1305,12 +1490,15 @@ async def stream_resume_chat(
|
|||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
|
||||
try:
|
||||
if user_id:
|
||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
|
||||
agent_config: AgentConfig | None = None
|
||||
_t0 = time.perf_counter()
|
||||
if llm_config_id >= 0:
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
|
|
@ -1334,26 +1522,54 @@ async def stream_resume_chat(
|
|||
return
|
||||
llm = create_chat_litellm_from_config(llm_config)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
_perf_log.info(
|
||||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
firecrawl_api_key = None
|
||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||
)
|
||||
if webcrawler_connector and webcrawler_connector.config:
|
||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||
_perf_log.info(
|
||||
"[stream_resume] Connector service + firecrawl key in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
checkpointer = await get_checkpointer()
|
||||
_perf_log.info(
|
||||
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
sandbox_backend = None
|
||||
_t0 = time.perf_counter()
|
||||
if is_sandbox_enabled():
|
||||
try:
|
||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
||||
except Exception as sandbox_err:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Sandbox creation failed, continuing without execute tool: %s",
|
||||
sandbox_err,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
sandbox_backend is not None,
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -1365,11 +1581,21 @@ async def stream_resume_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
sandbox_backend=sandbox_backend,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Release the transaction before streaming (same rationale as stream_new_chat).
|
||||
await session.commit()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
config = {
|
||||
|
|
@ -1380,7 +1606,8 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_message_start()
|
||||
yield streaming_service.format_start_step()
|
||||
|
||||
stream_result = StreamResult()
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
|
|
@ -1389,7 +1616,20 @@ async def stream_resume_chat(
|
|||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
if stream_result.is_interrupted:
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
@ -1423,3 +1663,5 @@ async def stream_resume_chat(
|
|||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ from collections.abc import Awaitable, Callable
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.airtable_history import AirtableHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -399,7 +399,7 @@ async def index_airtable_records(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"record_id": item["record_id"],
|
||||
"created_time": item["record"].get("CREATED_TIME()", ""),
|
||||
|
|
@ -415,11 +415,8 @@ async def index_airtable_records(
|
|||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Airtable Record: {item['record_id']}\n\n"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Airtable Record: {item['record_id']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
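The connector hunks in this commit share one decision: generate an LLM summary only when an LLM is configured and the connector's `enable_summary` flag is set, and otherwise embed a header plus the full content. A synchronous sketch of that branch follows. `build_summary` and its `summarize` and `embed` parameters are hypothetical stand-ins, since the real `generate_document_summary` is awaited and the signatures of both helpers are not shown in the diff.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any


def build_summary(
    header: str,
    full_content: str,
    *,
    user_llm: Any,
    enable_summary: bool,
    summarize: Callable[[str, Any], tuple[str, list[float]]],
    embed: Callable[[str], list[float]],
) -> tuple[str, list[float]]:
    """Return (summary_content, summary_embedding) for one indexed item."""
    if user_llm and enable_summary:
        # LLM path: the summarizer is assumed to return text and its embedding.
        return summarize(full_content, user_llm)
    # Fallback path: no truncation, the whole content follows the header.
    summary_content = f"{header}\n\n{full_content}"
    return summary_content, embed(summary_content)
```

Passing the summarizer and embedder in as callables keeps the branch testable without a model or an embedding backend.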
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.bookstack_connector import BookStackConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -403,7 +403,7 @@ async def index_bookstack_pages(
|
|||
"connector_id": connector_id,
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
summary_metadata = {
|
||||
"page_name": item["page_name"],
|
||||
"page_id": item["page_id"],
|
||||
|
|
@ -418,17 +418,8 @@ async def index_bookstack_pages(
|
|||
item["full_content"], user_llm, summary_metadata
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n"
|
||||
if item["page_content"]:
|
||||
# Take first 1000 characters of content for summary
|
||||
content_preview = item["page_content"][:1000]
|
||||
if len(item["page_content"]) > 1000:
|
||||
content_preview += "..."
|
||||
summary_content += f"Content Preview: {content_preview}\n\n"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"BookStack Page: {item['page_name']}\n\nBook ID: {item['book_id']}\n\n{item['full_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks - using the full page content
|
||||
chunks = await create_document_chunks(item["full_content"])
|
||||
|
|
|
|||
|
|
@ -14,13 +14,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.clickup_history import ClickUpHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -398,7 +398,7 @@ async def index_clickup_tasks(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"task_id": item["task_id"],
|
||||
"task_name": item["task_name"],
|
||||
|
|
@ -418,9 +418,7 @@ async def index_clickup_tasks(
|
|||
)
|
||||
else:
|
||||
summary_content = item["task_content"]
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
item["task_content"]
|
||||
)
|
||||
summary_embedding = embed_text(item["task_content"])
|
||||
|
||||
chunks = await create_document_chunks(item["task_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -14,13 +14,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -378,7 +378,7 @@ async def index_confluence_pages(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata = {
|
||||
"page_title": item["page_title"],
|
||||
"page_id": item["page_id"],
|
||||
|
|
@ -394,18 +394,8 @@ async def index_confluence_pages(
|
|||
item["full_content"], user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n"
|
||||
if item["page_content"]:
|
||||
# Take first 1000 characters of content for summary
|
||||
content_preview = item["page_content"][:1000]
|
||||
if len(item["page_content"]) > 1000:
|
||||
content_preview += "..."
|
||||
summary_content += f"Content Preview: {content_preview}\n\n"
|
||||
summary_content += f"Comments: {item['comment_count']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks - using the full page content with comments
|
||||
chunks = await create_document_chunks(item["full_content"])
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnector
|
|||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
|
@ -669,9 +670,7 @@ async def index_discord_messages(
|
|||
|
||||
# Heavy processing (embeddings, chunks)
|
||||
chunks = await create_document_chunks(item["combined_document_string"])
|
||||
doc_embedding = config.embedding_model_instance.embed(
|
||||
item["combined_document_string"]
|
||||
)
|
||||
doc_embedding = embed_text(item["combined_document_string"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = f"{item['guild_name']}#{item['channel_name']}"
|
||||
|
|
|
|||
|
|
@ -16,13 +16,13 @@ from datetime import UTC, datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.github_connector import GitHubConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -367,7 +367,7 @@ async def index_github_repos(
|
|||
"estimated_tokens": digest.estimated_tokens,
|
||||
}
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
# Prepare content for summarization
|
||||
summary_content = digest.full_digest
|
||||
if len(summary_content) > MAX_DIGEST_CHARS:
|
||||
|
|
@ -381,15 +381,12 @@ async def index_github_repos(
|
|||
summary_content, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_text = (
|
||||
f"# GitHub Repository: {repo_full_name}\n\n"
|
||||
f"## Summary\n{digest.summary}\n\n"
|
||||
f"## File Structure\n{digest.tree[:3000]}"
|
||||
)
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_text
|
||||
f"## File Structure\n{digest.tree}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_text)
|
||||
|
||||
# Chunk the full digest content for granular search
|
||||
try:
|
||||
|
|
@ -551,7 +548,7 @@ async def _simple_chunk_content(content: str, chunk_size: int = 4000) -> list:
|
|||
chunks.append(
|
||||
Chunk(
|
||||
content=chunk_text,
|
||||
embedding=config.embedding_model_instance.embed(chunk_text),
|
||||
embedding=embed_text(chunk_text),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from app.services.llm_service import get_user_long_context_llm
|
|||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -489,7 +490,7 @@ async def index_google_calendar_events(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"event_summary": item["event_summary"],
|
||||
|
|
@ -507,22 +508,8 @@ async def index_google_calendar_events(
|
|||
item["event_markdown"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
summary_content = (
|
||||
f"Google Calendar Event: {item['event_summary']}\n\n"
|
||||
)
|
||||
summary_content += f"Calendar: {item['calendar_id']}\n"
|
||||
summary_content += f"Start: {item['start_time']}\n"
|
||||
summary_content += f"End: {item['end_time']}\n"
|
||||
if item["location"]:
|
||||
summary_content += f"Location: {item['location']}\n"
|
||||
if item["description"]:
|
||||
desc_preview = item["description"][:1000]
|
||||
if len(item["description"]) > 1000:
|
||||
desc_preview += "..."
|
||||
summary_content += f"Description: {desc_preview}\n"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["event_markdown"])
|
||||
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@ async def index_google_drive_single_file(
|
|||
await session.commit()
|
||||
|
||||
# Process the file
|
||||
indexed, skipped, failed = await _process_single_file(
|
||||
indexed, _skipped, failed = await _process_single_file(
|
||||
drive_client=drive_client,
|
||||
session=session,
|
||||
file=file,
|
||||
|
|
@ -608,7 +608,7 @@ async def _index_with_delta_sync(
|
|||
{"stage": "delta_sync", "start_token": start_page_token},
|
||||
)
|
||||
|
||||
changes, final_token, error = await fetch_all_changes(
|
||||
changes, _final_token, error = await fetch_all_changes(
|
||||
drive_client, start_page_token, folder_id
|
||||
)
|
||||
|
||||
|
|
@ -1011,7 +1011,7 @@ async def _process_single_file(
|
|||
pending_document.status = DocumentStatus.processing()
|
||||
await session.commit()
|
||||
|
||||
_, error, metadata = await download_and_process_file(
|
||||
_, error, _metadata = await download_and_process_file(
|
||||
client=drive_client,
|
||||
file=file,
|
||||
search_space_id=search_space_id,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from app.services.llm_service import get_user_long_context_llm
|
|||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -413,7 +414,7 @@ async def index_google_gmail_messages(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"message_id": item["message_id"],
|
||||
"thread_id": item["thread_id"],
|
||||
|
|
@ -432,12 +433,8 @@ async def index_google_gmail_messages(
|
|||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
summary_content = f"Google Gmail Message: {item['subject']}\n\n"
|
||||
summary_content += f"Sender: {item['sender']}\n"
|
||||
summary_content += f"Date: {item['date_str']}\n"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Google Gmail Message: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -14,13 +14,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -356,7 +356,7 @@ async def index_jira_issues(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata = {
|
||||
"issue_key": item["issue_identifier"],
|
||||
"issue_title": item["issue_title"],
|
||||
|
|
@ -373,14 +373,8 @@ async def index_jira_issues(
|
|||
item["issue_content"], user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['formatted_issue'].get('status', 'Unknown')}\n\n"
|
||||
if item["formatted_issue"].get("description"):
|
||||
summary_content += f"Description: {item['formatted_issue'].get('description')}\n\n"
|
||||
summary_content += f"Comments: {item['comment_count']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\n{item['issue_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks - using the full issue content with comments
|
||||
chunks = await create_document_chunks(item["issue_content"])
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -395,7 +395,7 @@ async def index_linear_issues(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"issue_id": item["issue_identifier"],
|
||||
"issue_title": item["issue_title"],
|
||||
|
|
@ -412,17 +412,8 @@ async def index_linear_issues(
|
|||
item["issue_content"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
description = item["description"]
|
||||
if description and len(description) > 1000:
|
||||
description = description[:997] + "..."
|
||||
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n"
|
||||
if description:
|
||||
summary_content += f"Description: {description}\n\n"
|
||||
summary_content += f"Comments: {item['comment_count']}"
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n{item['issue_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["issue_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from datetime import datetime, timedelta
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.luma_connector import LumaConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -441,7 +441,7 @@ async def index_luma_events(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"event_id": item["event_id"],
|
||||
"event_name": item["event_name"],
|
||||
|
|
@ -462,29 +462,10 @@ async def index_luma_events(
|
|||
item["event_markdown"], user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Luma Event: {item['event_name']}\n\n"
|
||||
if item["event_url"]:
|
||||
summary_content += f"URL: {item['event_url']}\n"
|
||||
summary_content += f"Start: {item['start_at']}\n"
|
||||
summary_content += f"End: {item['end_at']}\n"
|
||||
if item["timezone"]:
|
||||
summary_content += f"Timezone: {item['timezone']}\n"
|
||||
if item["location"]:
|
||||
summary_content += f"Location: {item['location']}\n"
|
||||
if item["city"]:
|
||||
summary_content += f"City: {item['city']}\n"
|
||||
if item["host_names"]:
|
||||
summary_content += f"Hosts: {item['host_names']}\n"
|
||||
if item["description"]:
|
||||
desc_preview = item["description"][:1000]
|
||||
if len(item["description"]) > 1000:
|
||||
desc_preview += "..."
|
||||
summary_content += f"Description: {desc_preview}\n"
|
||||
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
summary_content = (
|
||||
f"Luma Event: {item['event_name']}\n\n{item['event_markdown']}"
|
||||
)
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["event_markdown"])
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -447,7 +447,7 @@ async def index_notion_pages(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"page_title": item["page_title"],
|
||||
"page_id": item["page_id"],
|
||||
|
|
@ -463,11 +463,8 @@ async def index_notion_pages(
|
|||
document_metadata_for_summary,
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content'][:500]}..."
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content']}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(item["markdown_content"])
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from app.services.llm_service import get_user_long_context_llm
|
|||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -546,7 +547,7 @@ async def index_obsidian_vault(
|
|||
|
||||
# Generate summary
|
||||
summary_content = ""
|
||||
if long_context_llm:
|
||||
if long_context_llm and connector.enable_summary:
|
||||
summary_content, _ = await generate_document_summary(
|
||||
document_string,
|
||||
long_context_llm,
|
||||
|
|
@ -554,7 +555,7 @@ async def index_obsidian_vault(
|
|||
)
|
||||
|
||||
# Generate embedding
|
||||
embedding = config.embedding_model_instance.embed(document_string)
|
||||
embedding = embed_text(document_string)
|
||||
|
||||
# Add URL and summary to metadata
|
||||
document_metadata["url"] = f"obsidian://{vault_name}/{relative_path}"
|
||||
|
|
|
|||
|
|
@ -17,12 +17,12 @@ from slack_sdk.errors import SlackApiError
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.slack_history import SlackHistory
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
|
@ -542,9 +542,7 @@ async def index_slack_messages(
|
|||
|
||||
# Heavy processing (embeddings, chunks)
|
||||
chunks = await create_document_chunks(item["combined_document_string"])
|
||||
doc_embedding = config.embedding_model_instance.embed(
|
||||
item["combined_document_string"]
|
||||
)
|
||||
doc_embedding = embed_text(item["combined_document_string"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = f"{item['team_name']}#{item['channel_name']}"
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ from datetime import UTC, datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.teams_history import TeamsHistory
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
|
@ -581,9 +581,7 @@ async def index_teams_messages(
|
|||
|
||||
# Heavy processing (embeddings, chunks)
|
||||
chunks = await create_document_chunks(item["combined_document_string"])
|
||||
doc_embedding = config.embedding_model_instance.embed(
|
||||
item["combined_document_string"]
|
||||
)
|
||||
doc_embedding = embed_text(item["combined_document_string"])
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = f"{item['team_name']} - {item['channel_name']}"
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.webcrawler_connector import WebCrawlerConnector
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -377,7 +377,7 @@ async def index_crawled_urls(
|
|||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
if user_llm and connector.enable_summary:
|
||||
document_metadata_for_summary = {
|
||||
"url": url,
|
||||
"title": title,
|
||||
|
|
@ -393,24 +393,8 @@ async def index_crawled_urls(
|
|||
structured_document, user_llm, document_metadata_for_summary
|
||||
)
|
||||
else:
|
||||
# Fallback to simple summary if no LLM configured
|
||||
summary_content = f"Crawled URL: {title}\n\n"
|
||||
summary_content += f"URL: {url}\n"
|
||||
if description:
|
||||
summary_content += f"Description: {description}\n"
|
||||
if language:
|
||||
summary_content += f"Language: {language}\n"
|
||||
summary_content += f"Crawler: {crawler_type}\n\n"
|
||||
|
||||
# Add content preview
|
||||
content_preview = content[:1000]
|
||||
if len(content) > 1000:
|
||||
content_preview += "..."
|
||||
summary_content += f"Content Preview:\n{content_preview}\n"
|
||||
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content
|
||||
)
|
||||
summary_content = f"Crawled URL: {title}\n\nURL: {url}\n\n{content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = await create_document_chunks(content)
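The connector hunks above share one pattern: the LLM summary now runs only when an LLM is configured and the connector's `enable_summary` flag is set, and otherwise the full content (no longer a 1000-character preview) is embedded behind a short header. A minimal sketch of that decision, assuming a hypothetical `ConnectorSettings` stand-in for the connector row and hypothetical helper names that do not appear in the diff:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ConnectorSettings:
    """Hypothetical stand-in for the connector row; only the flag matters here."""

    enable_summary: bool = False


def choose_summary_path(user_llm: object | None, connector: ConnectorSettings) -> str:
    """Return which branch an indexer would take for one item."""
    # Same condition as the diff: an LLM must exist AND summaries must be enabled.
    if user_llm and connector.enable_summary:
        return "llm_summary"
    return "full_content_fallback"


def build_fallback_summary(title: str, url: str, content: str) -> str:
    """Mirror the new web-crawler fallback: a header plus the untruncated content."""
    return f"Crawled URL: {title}\n\nURL: {url}\n\n{content}"
```

Because the fallback text is no longer truncated by hand, it can exceed the embedding model's context window, which is presumably why the same hunks route it through `embed_text` rather than calling the embedding model directly.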
|
||||
|
|
|
|||
|
|
@ -18,12 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentStatus, DocumentType, Log, Notification
|
||||
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.document_converters import (
|
||||
convert_document_to_markdown,
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
|
|
@ -33,7 +35,6 @@ from .base import (
|
|||
check_document_by_unique_identifier,
|
||||
check_duplicate_document,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
from .markdown_processor import add_received_markdown_file_document
|
||||
|
||||
|
|
@ -760,11 +761,7 @@ async def add_received_file_document_using_docling(
|
|||
f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
|
||||
)
|
||||
|
||||
from app.config import config
|
||||
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
enhanced_summary_content
|
||||
)
|
||||
summary_embedding = embed_text(enhanced_summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = await create_document_chunks(file_in_markdown)
|
||||
|
|
@ -1599,6 +1596,7 @@ async def process_file_in_background_with_document(
|
|||
log_entry: Log,
|
||||
connector: dict | None = None,
|
||||
notification: Notification | None = None,
|
||||
should_summarize: bool = False,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process file and update existing pending document (2-phase pattern).
|
||||
|
|
@ -1632,6 +1630,8 @@ async def process_file_in_background_with_document(
|
|||
from app.config import config as app_config
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
doc_id = document.id
|
||||
|
||||
try:
|
||||
markdown_content = None
|
||||
etl_service = None
|
||||
|
|
@ -1855,7 +1855,7 @@ async def process_file_in_background_with_document(
|
|||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
existing_by_content = await check_duplicate_document(session, content_hash)
|
||||
if existing_by_content and existing_by_content.id != document.id:
|
||||
if existing_by_content and existing_by_content.id != doc_id:
|
||||
# Duplicate content found - mark this document as failed
|
||||
logging.info(
|
||||
f"Duplicate content detected for {filename}, "
|
||||
|
|
@ -1863,7 +1863,7 @@ async def process_file_in_background_with_document(
|
|||
)
|
||||
return None
|
||||
|
||||
# ===== STEP 3: Generate embeddings and chunks =====
|
||||
# ===== STEP 3+4: Index via pipeline =====
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session, notification, stage="chunking"
|
||||
|
|
@ -1871,57 +1871,24 @@ async def process_file_in_background_with_document(
|
|||
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
|
||||
if user_llm:
|
||||
document_metadata = {
|
||||
"file_name": filename,
|
||||
"etl_service": etl_service,
|
||||
"document_type": "File Document",
|
||||
}
|
||||
summary_content, summary_embedding = await generate_document_summary(
|
||||
markdown_content, user_llm, document_metadata
|
||||
)
|
||||
else:
|
||||
# Fallback: use truncated content as summary
|
||||
summary_content = markdown_content[:4000]
|
||||
from app.config import config
|
||||
|
||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
|
||||
# ===== STEP 4: Update document to READY =====
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
document.title = filename
|
||||
document.content = summary_content
|
||||
document.content_hash = content_hash
|
||||
document.embedding = summary_embedding
|
||||
document.document_metadata = {
|
||||
"FILE_NAME": filename,
|
||||
"ETL_SERVICE": etl_service or "UNKNOWN",
|
||||
**(document.document_metadata or {}),
|
||||
}
|
||||
flag_modified(document, "document_metadata")
|
||||
|
||||
# Use safe_set_chunks to avoid async issues
|
||||
safe_set_chunks(document, chunks)
|
||||
|
||||
document.source_markdown = markdown_content
|
||||
document.content_needs_reindexing = False
|
||||
document.updated_at = get_current_timestamp()
|
||||
document.status = DocumentStatus.ready() # Shows checkmark in UI
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
await index_uploaded_file(
|
||||
markdown_content=markdown_content,
|
||||
filename=filename,
|
||||
etl_service=etl_service,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
llm=user_llm,
|
||||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file: {filename}",
|
||||
{
|
||||
"document_id": document.id,
|
||||
"document_id": doc_id,
|
||||
"content_hash": content_hash,
|
||||
"file_type": etl_service,
|
||||
"chunks_count": len(chunks),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -1946,7 +1913,7 @@ async def process_file_in_background_with_document(
|
|||
{
|
||||
"error_type": type(e).__name__,
|
||||
"filename": filename,
|
||||
"document_id": document.id,
|
||||
"document_id": doc_id,
|
||||
},
|
||||
)
|
||||
logging.error(f"Error processing file with document: {error_message}")
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.config import config
|
||||
from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument, async_session_maker
|
||||
from app.utils.document_converters import embed_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -89,7 +90,7 @@ def create_surfsense_docs_chunks(content: str) -> list[SurfsenseDocsChunk]:
|
|||
return [
|
||||
SurfsenseDocsChunk(
|
||||
content=chunk.text,
|
||||
embedding=config.embedding_model_instance.embed(chunk.text),
|
||||
embedding=embed_text(chunk.text),
|
||||
)
|
||||
for chunk in config.chunker_instance.chunk(content)
|
||||
]
|
||||
|
|
@ -154,7 +155,7 @@ async def index_surfsense_docs(session: AsyncSession) -> tuple[int, int, int, in
|
|||
existing_doc.title = title
|
||||
existing_doc.content = content
|
||||
existing_doc.content_hash = content_hash
|
||||
existing_doc.embedding = config.embedding_model_instance.embed(content)
|
||||
existing_doc.embedding = embed_text(content)
|
||||
existing_doc.chunks = chunks
|
||||
existing_doc.updated_at = datetime.now(UTC)
|
||||
|
||||
|
|
@ -170,7 +171,7 @@ async def index_surfsense_docs(session: AsyncSession) -> tuple[int, int, int, in
|
|||
title=title,
|
||||
content=content,
|
||||
content_hash=content_hash,
|
||||
embedding=config.embedding_model_instance.embed(content),
|
||||
embedding=embed_text(content),
|
||||
chunks=chunks,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,59 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from litellm import get_model_info, token_counter
|
||||
|
||||
from app.config import config
|
||||
from app.db import Chunk, DocumentType
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_embedding_max_tokens() -> int:
|
||||
"""Get the max token limit for the configured embedding model.
|
||||
|
||||
Checks model properties in order: max_seq_length, _max_tokens.
|
||||
Falls back to 8192 (OpenAI embedding default).
|
||||
"""
|
||||
model = config.embedding_model_instance
|
||||
for attr in ("max_seq_length", "_max_tokens"):
|
||||
val = getattr(model, attr, None)
|
||||
if isinstance(val, int) and val > 0:
|
||||
return val
|
||||
return 8192
|
||||
|
||||
|
||||
def truncate_for_embedding(text: str) -> str:
|
||||
"""Truncate text to fit within the embedding model's context window.
|
||||
|
||||
Uses the embedding model's own tokenizer for accurate token counting,
|
||||
so the result is model-agnostic regardless of the underlying provider.
|
||||
"""
|
||||
max_tokens = _get_embedding_max_tokens()
|
||||
if len(text) // 3 <= max_tokens:
|
||||
return text
|
||||
|
||||
tokenizer = config.embedding_model_instance.get_tokenizer()
|
||||
tokens = tokenizer.encode(text)
|
||||
if len(tokens) <= max_tokens:
|
||||
return text
|
||||
|
||||
warnings.warn(
|
||||
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return tokenizer.decode(tokens[:max_tokens])
|
||||
|
||||
|
||||
def embed_text(text: str) -> np.ndarray:
|
||||
"""Truncate text to fit and embed it. Drop-in replacement for
|
||||
``config.embedding_model_instance.embed(text)`` that never exceeds the
|
||||
model's context window."""
|
||||
return config.embedding_model_instance.embed(truncate_for_embedding(text))
|
||||
|
||||
|
||||
def get_model_context_window(model_name: str) -> int:
|
||||
"""Get the total context window size for a model (input + output tokens)."""
|
||||
|
|
@ -146,7 +194,7 @@ async def generate_document_summary(
|
|||
else:
|
||||
enhanced_summary_content = summary_content
|
||||
|
||||
summary_embedding = config.embedding_model_instance.embed(enhanced_summary_content)
|
||||
summary_embedding = embed_text(enhanced_summary_content)
|
||||
|
||||
return enhanced_summary_content, summary_embedding
|
||||
|
||||
|
|
@ -164,7 +212,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
|
|||
return [
|
||||
Chunk(
|
||||
content=chunk.text,
|
||||
embedding=config.embedding_model_instance.embed(chunk.text),
|
||||
embedding=embed_text(chunk.text),
|
||||
)
|
||||
for chunk in config.chunker_instance.chunk(content)
|
||||
]
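The truncation logic added to `document_converters.py` can be tried in isolation. The sketch below follows the same steps (cheap character-count pre-check, tokenize, cut, decode) but takes the tokenizer and budget as arguments; `WhitespaceTokenizer` and `truncate_to_budget` are hypothetical names for illustration, not part of the codebase.

```python
from __future__ import annotations

import warnings


class WhitespaceTokenizer:
    """Hypothetical tokenizer: one token per whitespace-separated word."""

    def encode(self, text: str) -> list[str]:
        return text.split()

    def decode(self, tokens: list[str]) -> str:
        return " ".join(tokens)


def truncate_to_budget(text: str, tokenizer: WhitespaceTokenizer, max_tokens: int) -> str:
    """Cut text to at most max_tokens tokens, tokenizing only when needed."""
    # Pre-check from the diff: assume roughly 3 characters per token and
    # return early when that estimate already fits the budget.
    if len(text) // 3 <= max_tokens:
        return text

    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return text

    warnings.warn(
        f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
        stacklevel=2,
    )
    return tokenizer.decode(tokens[:max_tokens])
```

Note that the pre-check is a heuristic: text whose tokens average fewer than three characters can pass it while still exceeding the budget, so it trades a little accuracy for skipping the tokenizer on short inputs.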
|
||||
|
|
|
|||
|
|
@ -29,4 +29,7 @@ if __name__ == "__main__":
|
|||
config = uvicorn.Config(**config_kwargs)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
server.run()
|
||||
if sys.platform == "win32":
|
||||
asyncio.run(server.serve(), loop_factory=asyncio.SelectorEventLoop)
|
||||
else:
|
||||
server.run()
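The Windows branch above passes `loop_factory` to `asyncio.run`, a parameter available from Python 3.12. The diff does not state why a selector loop is wanted on Windows; the sketch below only shows the same effect, running a coroutine on an explicitly constructed `SelectorEventLoop`, in a form that also works on older interpreters. `serve` is a hypothetical stand-in for `server.serve()`.

```python
import asyncio


async def serve() -> str:
    """Hypothetical stand-in for ``server.serve()``."""
    await asyncio.sleep(0)
    return "served"


def run_with_selector_loop() -> str:
    """Run the coroutine on a selector-based loop instead of the platform default."""
    loop = asyncio.SelectorEventLoop()
    try:
        return loop.run_until_complete(serve())
    finally:
        # Always release the loop's resources, even if serve() raises.
        loop.close()
```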
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ dependencies = [
|
|||
"kokoro>=0.9.4",
|
||||
"linkup-sdk>=0.2.4",
|
||||
"llama-cloud-services>=0.6.25",
|
||||
"Markdown>=3.7",
|
||||
"markdownify>=0.14.1",
|
||||
"notion-client>=2.3.0",
|
||||
"numpy>=1.24.0",
|
||||
|
|
@@ -65,11 +66,16 @@ dependencies = [
     "pypandoc_binary>=1.16.2",
     "typst>=0.14.0",
     "deepagents>=0.4.3",
+    "langchain-daytona>=0.0.2",
 ]

 [dependency-groups]
 dev = [
     "ruff>=0.12.5",
+    "pytest>=9.0.2",
+    "pytest-asyncio>=1.3.0",
+    "pytest-mock>=3.14",
+    "httpx>=0.28.1",
 ]

 [tool.ruff]
@@ -157,10 +163,27 @@ line-ending = "auto"

 [tool.ruff.lint.isort]
 # Group imports by type
-known-first-party = ["app"]
+known-first-party = ["app", "tests"]
 force-single-line = false
 combine-as-imports = true

+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "session"
+asyncio_default_test_loop_scope = "session"
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+python_classes = ["Test*"]
+python_functions = ["test_*"]
+addopts = "-v --tb=short -x --strict-markers -ra --durations=5"
+markers = [
+    "unit: pure logic tests, no DB or external services",
+    "integration: tests that require a real PostgreSQL database"
+]
+filterwarnings = [
+    "ignore::UserWarning:chonkie",
+]
+
 [tool.setuptools.packages.find]
 where = ["."]
 include = ["app*", "alembic*"]
0
surfsense_backend/tests/__init__.py
Normal file
61
surfsense_backend/tests/conftest.py
Normal file
@@ -0,0 +1,61 @@
"""Root conftest — shared fixtures available to all test modules."""

from __future__ import annotations

import os

_DEFAULT_TEST_DB = (
    "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
)
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)

# Force the app to use the test database regardless of any pre-existing
# DATABASE_URL in the environment (e.g. from .env or shell profile).
os.environ["DATABASE_URL"] = TEST_DATABASE_URL

import pytest  # noqa: E402

from app.db import DocumentType  # noqa: E402
from app.indexing_pipeline.connector_document import ConnectorDocument  # noqa: E402

# ---------------------------------------------------------------------------
# Unit test fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def sample_user_id() -> str:
    return "00000000-0000-0000-0000-000000000001"


@pytest.fixture
def sample_search_space_id() -> int:
    return 1


@pytest.fixture
def sample_connector_id() -> int:
    return 42


@pytest.fixture
def make_connector_document():
    """
    Generic factory for unit tests. Overridden in tests/integration/conftest.py
    with real DB-backed IDs for integration tests.
    """

    def _make(**overrides):
        defaults = {
            "title": "Test Document",
            "source_markdown": "## Heading\n\nSome content.",
            "unique_id": "test-id-001",
            "document_type": DocumentType.CLICKUP_CONNECTOR,
            "search_space_id": 1,
            "connector_id": 1,
            "created_by_id": "00000000-0000-0000-0000-000000000001",
        }
        defaults.update(overrides)
        return ConnectorDocument(**defaults)

    return _make
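Review note: `make_connector_document` is a factory fixture, meaning the fixture returns a function that builds an object from defaults plus per-test overrides. A minimal sketch of the pattern without pytest, assuming a hypothetical `FakeDocument` dataclass in place of `ConnectorDocument` (whose full field list is not shown here):

```python
from dataclasses import dataclass


@dataclass
class FakeDocument:
    """Hypothetical stand-in for ConnectorDocument, reduced to three fields."""

    title: str
    unique_id: str
    search_space_id: int


def make_document_factory():
    """Return a builder that merges keyword overrides into fixed defaults."""

    def _make(**overrides) -> FakeDocument:
        defaults = {
            "title": "Test Document",
            "unique_id": "test-id-001",
            "search_space_id": 1,
        }
        defaults.update(overrides)  # overrides win over defaults
        return FakeDocument(**defaults)

    return _make


make = make_document_factory()
default_doc = make()
custom_doc = make(title="Quarterly report", search_space_id=7)
```

Each test then states only the fields it cares about, and unrelated fields keep their defaults.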
0
surfsense_backend/tests/fixtures/empty.pdf
vendored
Normal file
51
surfsense_backend/tests/fixtures/sample.md
vendored
Normal file
@@ -0,0 +1,51 @@
# SurfSense Test Document

## Overview

This is a **sample markdown document** used for end-to-end testing of the manual
document upload pipeline. It includes various markdown formatting elements.

## Key Features

- Document upload and processing
- Automatic chunking of content
- Embedding generation for semantic search
- Real-time status tracking via ElectricSQL

## Technical Architecture

### Backend Stack

The SurfSense backend is built with:

1. **FastAPI** for the REST API
2. **PostgreSQL** with pgvector for vector storage
3. **Celery** with Redis for background task processing
4. **Docling/Unstructured** for document parsing (ETL)

### Processing Pipeline

Documents go through a multi-stage pipeline:

| Stage | Description |
|-------|-------------|
| Upload | File received via API endpoint |
| Parsing | Content extracted using ETL service |
| Chunking | Text split into semantic chunks |
| Embedding | Vector representations generated |
| Storage | Chunks stored with embeddings in pgvector |

## Code Example

```python
async def process_document(file_path: str) -> Document:
    content = extract_content(file_path)
    chunks = create_chunks(content)
    embeddings = generate_embeddings(chunks)
    return store_document(chunks, embeddings)
```

## Conclusion

This document serves as a test fixture to validate the complete document processing
pipeline from upload through to chunk creation and embedding storage.
BIN
surfsense_backend/tests/fixtures/sample.pdf
vendored
Normal file
Binary file not shown.
34
surfsense_backend/tests/fixtures/sample.txt
vendored
Normal file
@@ -0,0 +1,34 @@
SurfSense Document Upload Test

This is a sample text document used for end-to-end testing of the manual document
upload pipeline in SurfSense. The document contains multiple paragraphs to ensure
that the chunking system has enough content to work with.

Artificial Intelligence and Machine Learning

Artificial intelligence (AI) is a broad field of computer science concerned with
building smart machines capable of performing tasks that typically require human
intelligence. Machine learning is a subset of AI that enables systems to learn and
improve from experience without being explicitly programmed.

Natural Language Processing

Natural language processing (NLP) is a subfield of linguistics, computer science,
and artificial intelligence concerned with the interactions between computers and
human language. Key applications include machine translation, sentiment analysis,
text summarization, and question answering systems.

Vector Databases and Semantic Search

Vector databases store data as high-dimensional vectors, enabling efficient
similarity search operations. When combined with embedding models, they power
semantic search systems that understand the meaning behind queries rather than
relying on exact keyword matches. This technology is fundamental to modern
retrieval-augmented generation (RAG) systems.

Document Processing Pipelines

Modern document processing pipelines involve several stages: extraction, transformation,
chunking, embedding generation, and storage. Each stage plays a critical role in
converting raw documents into searchable, structured knowledge that can be retrieved
and used by AI systems for accurate information retrieval and generation.
0
surfsense_backend/tests/integration/__init__.py
Normal file
168
surfsense_backend/tests/integration/conftest.py
Normal file
@@ -0,0 +1,168 @@
import uuid
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool

from app.config import config as app_config
from app.db import (
    Base,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
    SearchSpace,
    User,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from tests.conftest import TEST_DATABASE_URL

_EMBEDDING_DIM = app_config.embedding_model_instance.dimension


@pytest_asyncio.fixture(scope="session")
async def async_engine():
    engine = create_async_engine(
        TEST_DATABASE_URL,
        poolclass=NullPool,
        echo=False,
        # Required for asyncpg + savepoints: disables prepared statement cache
        # to prevent "another operation is in progress" errors during savepoint rollbacks.
        connect_args={"prepared_statement_cache_size": 0},
    )

    async with engine.begin() as conn:
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        await conn.run_sync(Base.metadata.create_all)

    yield engine

    # drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
    # DROP SCHEMA CASCADE handles this without needing topological sort.
    async with engine.begin() as conn:
        await conn.execute(text("DROP SCHEMA public CASCADE"))
        await conn.execute(text("CREATE SCHEMA public"))

    await engine.dispose()


@pytest_asyncio.fixture
async def db_session(async_engine) -> AsyncSession:
    # Bind the session to a connection that holds an outer transaction.
    # join_transaction_mode="create_savepoint" makes session.commit() release
    # a SAVEPOINT instead of committing the outer transaction, so the final
    # transaction.rollback() undoes everything — including commits made by the
    # service under test — leaving the DB clean for the next test.
    async with async_engine.connect() as conn:
        transaction = await conn.begin()
        async with AsyncSession(
            bind=conn,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        ) as session:
            yield session
        await transaction.rollback()


@pytest_asyncio.fixture
async def db_user(db_session: AsyncSession) -> User:
    user = User(
        id=uuid.uuid4(),
        email="test@surfsense.net",
        hashed_password="hashed",
        is_active=True,
        is_superuser=False,
        is_verified=True,
    )
    db_session.add(user)
    await db_session.flush()
    return user


@pytest_asyncio.fixture
async def db_connector(
    db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace"
) -> SearchSourceConnector:
    connector = SearchSourceConnector(
        name="Test Connector",
        connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
        config={},
        search_space_id=db_search_space.id,
        user_id=db_user.id,
    )
    db_session.add(connector)
    await db_session.flush()
    return connector


@pytest_asyncio.fixture
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
    space = SearchSpace(
        name="Test Space",
        user_id=db_user.id,
    )
    db_session.add(space)
    await db_session.flush()
    return space


@pytest.fixture
def patched_summarize(monkeypatch) -> AsyncMock:
    mock = AsyncMock(return_value="Mocked summary.")
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
        mock,
    )
    return mock


@pytest.fixture
def patched_summarize_raises(monkeypatch) -> AsyncMock:
    mock = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
        mock,
    )
    return mock


@pytest.fixture
def patched_embed_text(monkeypatch) -> MagicMock:
    mock = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.embed_text",
        mock,
    )
    return mock


@pytest.fixture
def patched_chunk_text(monkeypatch) -> MagicMock:
    mock = MagicMock(return_value=["Test chunk content."])
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.chunk_text",
        mock,
    )
    return mock


@pytest.fixture
def make_connector_document(db_connector, db_user):
    """Integration-scoped override: uses real DB connector and user IDs."""

    def _make(**overrides):
        defaults = {
            "title": "Test Document",
            "source_markdown": "## Heading\n\nSome content.",
            "unique_id": "test-id-001",
            "document_type": DocumentType.CLICKUP_CONNECTOR,
            "search_space_id": db_connector.search_space_id,
            "connector_id": db_connector.id,
            "created_by_id": str(db_user.id),
        }
        defaults.update(overrides)
        return ConnectorDocument(**defaults)

    return _make
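Review note: the `db_session` fixture relies on one property of SQL transactions, namely that work released from a SAVEPOINT is still discarded when the enclosing transaction rolls back. The sketch below illustrates only that property. It uses the standard-library `sqlite3` module instead of SQLAlchemy with asyncpg, and the table and savepoint names are made up for the illustration.

```python
import sqlite3

# isolation_level=None puts the driver in autocommit mode, so the
# BEGIN / SAVEPOINT / ROLLBACK statements below are sent exactly as written.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE documents (id INTEGER PRIMARY KEY, title TEXT)")

# Outer transaction, playing the role of the one the fixture opens.
conn.execute("BEGIN")

# Inside the test, a "commit" by the code under test becomes a savepoint release.
conn.execute("SAVEPOINT service_work")
conn.execute("INSERT INTO documents (title) VALUES ('indexed by the service')")
conn.execute("RELEASE SAVEPOINT service_work")

# The row is visible for the remainder of the test.
rows_during_test = conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0]

# Fixture teardown: rolling back the outer transaction discards the row.
conn.execute("ROLLBACK")
rows_after_test = conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0]
conn.close()
```

The row count is 1 while the outer transaction is open and 0 after the rollback, which is why each test starts from a clean database without any explicit cleanup code.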
289
surfsense_backend/tests/integration/document_upload/conftest.py
Normal file
@@ -0,0 +1,289 @@
"""Integration conftest — runs the FastAPI app in-process via ASGITransport.

Prerequisites: PostgreSQL + pgvector only.

External system boundaries are mocked:
- LLM summarization, text embedding, text chunking (external APIs)
- Redis heartbeat (external infrastructure)
- Task dispatch is swapped via DI (InlineTaskDispatcher)
"""

from __future__ import annotations

import contextlib
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock

import asyncpg
import httpx
import pytest
from httpx import ASGITransport
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool

from app.app import app
from app.config import config as app_config
from app.db import Base
from app.services.task_dispatcher import get_task_dispatcher
from tests.integration.conftest import TEST_DATABASE_URL
from tests.utils.helpers import (
    TEST_EMAIL,
    auth_headers,
    delete_document,
    get_auth_token,
    get_search_space_id,
)

_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
_ASYNCPG_URL = TEST_DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://")

pytestmark = pytest.mark.integration


# ---------------------------------------------------------------------------
# Inline task dispatcher (replaces Celery via DI — not a mock)
# ---------------------------------------------------------------------------


class InlineTaskDispatcher:
    """Processes files synchronously in the calling coroutine.

    Swapped in via FastAPI dependency_overrides so the upload endpoint
    processes documents inline instead of dispatching to Celery.

    Exceptions are caught to match Celery's fire-and-forget semantics —
    the processing function already marks documents as failed internally.
    """

    async def dispatch_file_processing(
        self,
        *,
        document_id: int,
        temp_path: str,
        filename: str,
        search_space_id: int,
        user_id: str,
        should_summarize: bool = False,
    ) -> None:
        from app.tasks.celery_tasks.document_tasks import (
            _process_file_with_document,
        )

        with contextlib.suppress(Exception):
            await _process_file_with_document(
                document_id,
                temp_path,
                filename,
                search_space_id,
                user_id,
                should_summarize=should_summarize,
            )


app.dependency_overrides[get_task_dispatcher] = lambda: InlineTaskDispatcher()


# ---------------------------------------------------------------------------
# Database setup (ASGITransport skips the app lifespan)
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
async def _ensure_tables():
    """Create DB tables and extensions once per session."""
    engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool)
    async with engine.begin() as conn:
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
        await conn.run_sync(Base.metadata.create_all)
    await engine.dispose()


# ---------------------------------------------------------------------------
# Auth & search space (session-scoped, via the in-process app)
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
async def auth_token(_ensure_tables) -> str:
    """Authenticate once per session, registering the user if needed."""
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=30.0
    ) as c:
        return await get_auth_token(c)


@pytest.fixture(scope="session")
async def search_space_id(auth_token: str) -> int:
    """Discover the first search space belonging to the test user."""
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=30.0
    ) as c:
        return await get_search_space_id(c, auth_token)


@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
    return auth_headers(auth_token)


# ---------------------------------------------------------------------------
# Per-test HTTP client & cleanup
# ---------------------------------------------------------------------------


@pytest.fixture
async def client() -> AsyncGenerator[httpx.AsyncClient]:
    """Per-test async HTTP client using ASGITransport (no running server)."""
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=180.0
    ) as c:
        yield c


@pytest.fixture
def cleanup_doc_ids() -> list[int]:
    """Accumulator for document IDs that should be deleted after the test."""
    return []


@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(search_space_id: int):
    """Delete stale documents from previous runs before the session starts."""
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        result = await conn.execute(
            "DELETE FROM documents WHERE search_space_id = $1",
            search_space_id,
        )
        deleted = int(result.split()[-1])
        if deleted:
            print(
                f"\n[purge] Deleted {deleted} stale document(s) "
                f"from search space {search_space_id}"
            )
    finally:
        await conn.close()
    yield


@pytest.fixture(autouse=True)
async def _cleanup_documents(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    cleanup_doc_ids: list[int],
):
    """Delete test documents after every test (API first, DB fallback)."""
    yield

    remaining_ids: list[int] = []
    for doc_id in cleanup_doc_ids:
        try:
            resp = await delete_document(client, headers, doc_id)
            if resp.status_code == 409:
                remaining_ids.append(doc_id)
        except Exception:
            remaining_ids.append(doc_id)

    if remaining_ids:
        conn = await asyncpg.connect(_ASYNCPG_URL)
        try:
            await conn.execute(
                "DELETE FROM documents WHERE id = ANY($1::int[])",
                remaining_ids,
            )
        finally:
            await conn.close()


# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB for setup, API for verification)
# ---------------------------------------------------------------------------


async def _get_user_page_usage(email: str) -> tuple[int, int]:
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        row = await conn.fetchrow(
            'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
            email,
        )
        assert row is not None, f"User {email!r} not found in database"
        return row["pages_used"], row["pages_limit"]
    finally:
        await conn.close()


async def _set_user_page_limits(
    email: str, *, pages_used: int, pages_limit: int
) -> None:
    conn = await asyncpg.connect(_ASYNCPG_URL)
    try:
        await conn.execute(
            'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
            pages_used,
            pages_limit,
            email,
        )
    finally:
        await conn.close()


@pytest.fixture
async def page_limits():
    """Manipulate the test user's page limits (direct DB for setup only).

    Automatically restores original values after each test.
    """

    class _PageLimits:
        async def set(self, *, pages_used: int, pages_limit: int) -> None:
            await _set_user_page_limits(
                TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
            )

    original = await _get_user_page_usage(TEST_EMAIL)
    yield _PageLimits()
    await _set_user_page_limits(
        TEST_EMAIL, pages_used=original[0], pages_limit=original[1]
    )


# ---------------------------------------------------------------------------
# Mock external system boundaries
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _mock_external_apis(monkeypatch):
    """Mock LLM, embedding, and chunking — these are external API boundaries."""
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
        AsyncMock(return_value="Mocked summary."),
    )
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.embed_text",
        MagicMock(return_value=[0.1] * _EMBEDDING_DIM),
    )
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.chunk_text",
        MagicMock(return_value=["Test chunk content."]),
    )


@pytest.fixture(autouse=True)
def _mock_redis_heartbeat(monkeypatch):
    """Mock Redis heartbeat — Redis is an external infrastructure boundary."""
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._start_heartbeat",
        lambda notification_id: None,
    )
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._stop_heartbeat",
        lambda notification_id: None,
    )
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._run_heartbeat_loop",
        AsyncMock(),
    )
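Review note: `InlineTaskDispatcher` runs the processing job in the caller and suppresses its exceptions, on the stated assumption that the job records its own failure first. A minimal sketch of that contract using only the standard library; `InlineDispatcher`, `process` and the `statuses` dict are hypothetical names for this illustration and do not correspond to the project's task code or to FastAPI's dependency system.

```python
import asyncio
import contextlib


class InlineDispatcher:
    """Runs the job in the calling coroutine and swallows its exceptions."""

    def __init__(self, job):
        self._job = job

    async def dispatch(self, document_id: int) -> None:
        # Fire-and-forget semantics: the caller never sees the job's exception.
        with contextlib.suppress(Exception):
            await self._job(document_id)


statuses: dict[int, str] = {}


async def process(document_id: int) -> None:
    """Hypothetical job: records its own failure before re-raising."""
    try:
        if document_id < 0:
            raise ValueError("unparseable file")
        statuses[document_id] = "ready"
    except ValueError:
        statuses[document_id] = "failed"
        raise


async def main() -> None:
    dispatcher = InlineDispatcher(process)
    await dispatcher.dispatch(1)
    await dispatcher.dispatch(-1)  # raises inside, but dispatch() returns normally


asyncio.run(main())
```

Because the failure is written to `statuses` before the exception is suppressed, a test can still observe the failed state by polling, which mirrors how the upload tests check for `"failed"` documents.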
@ -0,0 +1,337 @@
|
|||
"""
|
||||
Integration tests for the document upload HTTP API.
|
||||
|
||||
Covers the API contract, auth, duplicate detection, and error handling.
|
||||
Pipeline internals are tested in the ``indexing_pipeline`` suite.
|
||||
|
||||
Requires PostgreSQL + pgvector.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.utils.helpers import (
|
||||
FIXTURES_DIR,
|
||||
poll_document_status,
|
||||
upload_file,
|
||||
upload_multiple_files,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload smoke tests (one per distinct code-path: direct-read & ETL)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTxtFileUpload:
|
||||
"""Upload a plain-text file (direct-read path) via the HTTP API."""
|
||||
|
||||
async def test_upload_txt_returns_document_id(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["pending_files"] >= 1
|
||||
assert len(body["document_ids"]) >= 1
|
||||
cleanup_doc_ids.extend(body["document_ids"])
|
||||
|
||||
async def test_txt_processing_reaches_ready(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
|
||||
class TestPdfFileUpload:
|
||||
"""Upload a PDF (ETL extraction path) via the HTTP API."""
|
||||
|
||||
async def test_pdf_processing_reaches_ready(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test D: Upload multiple files in a single request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiFileUpload:
|
||||
"""Upload several files at once and verify the API response contract."""
|
||||
|
||||
async def test_multi_upload_returns_all_ids(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp = await upload_multiple_files(
|
||||
client,
|
||||
headers,
|
||||
["sample.txt", "sample.md"],
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["pending_files"] == 2
|
||||
assert len(body["document_ids"]) == 2
|
||||
cleanup_doc_ids.extend(body["document_ids"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test E: Duplicate file upload (same file uploaded twice)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDuplicateFileUpload:
|
||||
"""
|
||||
Uploading the exact same file a second time should be detected as a
|
||||
duplicate via ``unique_identifier_hash``.
|
||||
"""
|
||||
|
||||
async def test_duplicate_file_is_skipped(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
resp1 = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp1.status_code == 200
|
||||
first_ids = resp1.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(first_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, first_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
resp2 = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp2.status_code == 200
|
||||
|
||||
body2 = resp2.json()
|
||||
assert body2["skipped_duplicates"] >= 1
|
||||
assert len(body2["duplicate_document_ids"]) >= 1
|
||||
cleanup_doc_ids.extend(body2.get("document_ids", []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test F: Duplicate content detection (different name, same content)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDuplicateContentDetection:
|
||||
"""
|
||||
Uploading a file with a different name but identical content should be
|
||||
detected as duplicate content via ``content_hash``.
|
||||
"""
|
||||
|
||||
async def test_same_content_different_name_detected(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
tmp_path: Path,
|
||||
):
|
||||
resp1 = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp1.status_code == 200
|
||||
first_ids = resp1.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(first_ids)
|
||||
await poll_document_status(
|
||||
client, headers, first_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
src = FIXTURES_DIR / "sample.txt"
|
||||
dest = tmp_path / "renamed_sample.txt"
|
||||
shutil.copy2(src, dest)
|
||||
|
||||
with open(dest, "rb") as f:
|
||||
resp2 = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files={"files": ("renamed_sample.txt", f)},
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp2.status_code == 200
|
||||
second_ids = resp2.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(second_ids)
|
||||
assert second_ids, (
|
||||
"Expected at least one document id for renamed duplicate content upload"
|
||||
)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, second_ids, search_space_id=search_space_id
|
||||
)
|
||||
for did in second_ids:
|
||||
assert statuses[did]["status"]["state"] == "failed"
|
||||
assert "duplicate" in statuses[did]["status"].get("reason", "").lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test G: Empty / corrupt file handling
# ---------------------------------------------------------------------------


class TestEmptyFileUpload:
    """An empty PDF is accepted for upload but must end in ``failed`` with a reason."""

    async def test_empty_pdf_fails(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_file(
            client, headers, "empty.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200

        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        assert doc_ids, "Expected at least one document id for empty PDF upload"

        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "failed"
            assert statuses[did]["status"].get("reason"), (
                "Failed document should include a reason"
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test H: Upload without authentication
# ---------------------------------------------------------------------------


class TestUnauthenticatedUpload:
    """Requests without a valid JWT should be rejected."""

    async def test_upload_without_auth_returns_401(
        self,
        client: httpx.AsyncClient,
        search_space_id: int,
    ):
        file_path = FIXTURES_DIR / "sample.txt"
        with open(file_path, "rb") as f:
            resp = await client.post(
                "/api/v1/documents/fileupload",
                files={"files": ("sample.txt", f)},
                data={"search_space_id": str(search_space_id)},
            )
        assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test I: Upload with no files attached
# ---------------------------------------------------------------------------


class TestNoFilesUpload:
    """Submitting the form with zero files should return a validation error (400 or 422)."""

    async def test_no_files_returns_error(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code in {400, 422}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test K: Searchability after upload
# ---------------------------------------------------------------------------


class TestDocumentSearchability:
    """After an upload reaches ``ready``, the document must appear in title search results."""

    async def test_uploaded_document_appears_in_search(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)

        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )

        search_resp = await client.get(
            "/api/v1/documents/search",
            headers=headers,
            params={"title": "sample", "search_space_id": search_space_id},
        )
        assert search_resp.status_code == 200

        result_ids = [d["id"] for d in search_resp.json()["items"]]
        assert doc_ids[0] in result_ids, (
            f"Uploaded document {doc_ids[0]} not found in search results: {result_ids}"
        )
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
"""
|
||||
Integration tests for page-limit enforcement during document upload.

These tests manipulate the test user's ``pages_used`` / ``pages_limit``
columns directly in the database (setup only) and then exercise the upload
pipeline to verify that:

- Uploads are rejected *before* ETL when the limit is exhausted.
- ``pages_used`` increases after a successful upload (verified via API).
- A ``page_limit_exceeded`` notification is created on rejection.
- ``pages_used`` is not modified when a document fails processing.

All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``,
``empty.pdf``) so no additional processing time is introduced.

Prerequisites:
- PostgreSQL + pgvector
"""

from __future__ import annotations

import httpx
import pytest

from tests.utils.helpers import (
    get_notifications,
    poll_document_status,
    upload_file,
)

pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Helper: read pages_used through the public API
# ---------------------------------------------------------------------------


async def _get_pages_used(client: httpx.AsyncClient, headers: dict[str, str]) -> int:
    """Fetch the current user's pages_used via the /users/me API."""
    resp = await client.get("/users/me", headers=headers)
    assert resp.status_code == 200, (
        f"GET /users/me failed ({resp.status_code}): {resp.text}"
    )
    return resp.json()["pages_used"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test A: Successful upload increments pages_used
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPageUsageIncrementsOnSuccess:
|
||||
"""After a successful PDF upload the user's ``pages_used`` must grow."""
|
||||
|
||||
async def test_pages_used_increases_after_pdf_upload(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=0, pages_limit=1000)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "ready"
|
||||
|
||||
used = await _get_pages_used(client, headers)
|
||||
assert used > 0, "pages_used should have increased after successful processing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test B: Upload rejected when page limit is fully exhausted
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUploadRejectedWhenLimitExhausted:
    """
    When ``pages_used == pages_limit`` (zero remaining) the document
    should reach ``failed`` status with a page-limit reason.
    """
|
||||
|
||||
async def test_pdf_fails_when_no_pages_remaining(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=100, pages_limit=100)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "failed"
|
||||
reason = statuses[did]["status"].get("reason", "").lower()
|
||||
assert "page limit" in reason, (
|
||||
f"Expected 'page limit' in failure reason, got: {reason!r}"
|
||||
)
|
||||
|
||||
async def test_pages_used_unchanged_after_limit_rejection(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=50, pages_limit=50)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
|
||||
used = await _get_pages_used(client, headers)
|
||||
assert used == 50, (
|
||||
f"pages_used should remain 50 after rejected upload, got {used}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test C: Page-limit notification is created on rejection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPageLimitNotification:
    """A ``page_limit_exceeded`` notification must be created when an upload
    is rejected due to the limit."""
|
||||
|
||||
async def test_page_limit_exceeded_notification_created(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=100, pages_limit=100)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
|
||||
notifications = await get_notifications(
|
||||
client,
|
||||
headers,
|
||||
type_filter="page_limit_exceeded",
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
assert len(notifications) >= 1, (
|
||||
"Expected at least one page_limit_exceeded notification"
|
||||
)
|
||||
|
||||
latest = notifications[0]
|
||||
assert (
|
||||
"page limit" in latest["title"].lower()
|
||||
or "page limit" in latest["message"].lower()
|
||||
), (
|
||||
f"Notification should mention page limit: title={latest['title']!r}, "
|
||||
f"message={latest['message']!r}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test D: Successful upload creates a completed document_processing notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentProcessingNotification:
    """A ``document_processing`` notification with ``completed`` status must
    exist after a successful upload."""
|
||||
|
||||
async def test_processing_completed_notification_exists(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=0, pages_limit=1000)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "sample.txt", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id
|
||||
)
|
||||
|
||||
notifications = await get_notifications(
|
||||
client,
|
||||
headers,
|
||||
type_filter="document_processing",
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
completed = [
|
||||
n
|
||||
for n in notifications
|
||||
if n.get("metadata", {}).get("processing_stage") == "completed"
|
||||
]
|
||||
assert len(completed) >= 1, (
|
||||
"Expected at least one document_processing notification with 'completed' stage"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test E: pages_used unchanged when a document fails for non-limit reasons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPagesUnchangedOnProcessingFailure:
    """If a document fails during ETL (e.g. an empty or corrupt file) rather
    than through a page-limit rejection, ``pages_used`` should remain unchanged."""
|
||||
|
||||
async def test_pages_used_stable_on_etl_failure(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=10, pages_limit=1000)
|
||||
|
||||
resp = await upload_file(
|
||||
client, headers, "empty.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
doc_ids = resp.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(doc_ids)
|
||||
|
||||
if doc_ids:
|
||||
statuses = await poll_document_status(
|
||||
client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
|
||||
)
|
||||
for did in doc_ids:
|
||||
assert statuses[did]["status"]["state"] == "failed"
|
||||
|
||||
used = await _get_pages_used(client, headers)
|
||||
assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test F: Second upload rejected after first consumes remaining quota
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecondUploadExceedsLimit:
    """Upload one PDF successfully, consuming the quota, then verify a
    second upload is rejected."""
|
||||
|
||||
async def test_second_upload_rejected_after_quota_consumed(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
page_limits,
|
||||
):
|
||||
await page_limits.set(pages_used=0, pages_limit=1)
|
||||
|
||||
resp1 = await upload_file(
|
||||
client, headers, "sample.pdf", search_space_id=search_space_id
|
||||
)
|
||||
assert resp1.status_code == 200
|
||||
first_ids = resp1.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(first_ids)
|
||||
|
||||
statuses1 = await poll_document_status(
|
||||
client, headers, first_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in first_ids:
|
||||
assert statuses1[did]["status"]["state"] == "ready"
|
||||
|
||||
resp2 = await upload_file(
|
||||
client,
|
||||
headers,
|
||||
"sample.pdf",
|
||||
search_space_id=search_space_id,
|
||||
filename_override="sample_copy.pdf",
|
||||
)
|
||||
assert resp2.status_code == 200
|
||||
second_ids = resp2.json()["document_ids"]
|
||||
cleanup_doc_ids.extend(second_ids)
|
||||
|
||||
statuses2 = await poll_document_status(
|
||||
client, headers, second_ids, search_space_id=search_space_id, timeout=300.0
|
||||
)
|
||||
for did in second_ids:
|
||||
assert statuses2[did]["status"]["state"] == "failed"
|
||||
reason = statuses2[did]["status"].get("reason", "").lower()
|
||||
assert "page limit" in reason, (
|
||||
f"Expected 'page limit' in failure reason, got: {reason!r}"
|
||||
)
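
Taken together, the page-limit tests above describe a check-then-charge pattern: the quota is checked before any ETL work, and `pages_used` only grows after processing succeeds. The sketch below illustrates that pattern under stated assumptions. `PageQuota`, `PageLimitExceeded`, `process_upload`, and the page-estimate parameter are hypothetical names introduced here, not the backend's actual API.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


class PageLimitExceeded(Exception):
    """Raised before any ETL work when the upload does not fit in the quota."""


@dataclass
class PageQuota:
    pages_used: int
    pages_limit: int

    def check(self, estimated_pages: int) -> None:
        """Reject up front if the estimate exceeds the remaining quota."""
        remaining = self.pages_limit - self.pages_used
        if estimated_pages > remaining:
            raise PageLimitExceeded(
                f"Page limit exceeded: {self.pages_used}/{self.pages_limit} used, "
                f"{estimated_pages} more requested"
            )

    def commit(self, actual_pages: int) -> None:
        """Charge the quota only after processing has succeeded."""
        self.pages_used += actual_pages


def process_upload(
    quota: PageQuota, estimated_pages: int, run_etl: Callable[[], int]
) -> dict[str, str]:
    """Return a status dict; run_etl is assumed to return the real page count."""
    try:
        quota.check(estimated_pages)
    except PageLimitExceeded as exc:
        # Rejected before ETL: pages_used is untouched.
        return {"state": "failed", "reason": str(exc)}
    try:
        actual_pages = run_etl()
    except Exception as exc:
        # ETL failure (e.g. an empty file): pages_used is untouched as well.
        return {"state": "failed", "reason": f"Processing failed: {exc}"}
    quota.commit(actual_pages)
    return {"state": "ready"}
```

With `pages_limit=1`, a first one-page upload passes the check and consumes the quota, and a second one is rejected with a reason containing "page limit", which is the sequence Test F exercises.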
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
Integration tests for backend file upload limit enforcement.

These tests verify that the API rejects uploads that exceed:
- Max files per upload (10)
- Max per-file size (50 MB)
- Max total upload size (200 MB)

The limits mirror the frontend's DocumentUploadTab.tsx constants and are
enforced server-side to protect against direct API calls.

Prerequisites:
- PostgreSQL + pgvector
"""

from __future__ import annotations

import io

import httpx
import pytest

pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test A: File count limit
# ---------------------------------------------------------------------------


class TestFileCountLimit:
    """Uploading more than 10 files in a single request should be rejected."""

    async def test_11_files_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        files = [
            ("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
            for i in range(11)
        ]
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=files,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 413
        assert "too many files" in resp.json()["detail"].lower()

    async def test_10_files_accepted(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        files = [
            ("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
            for i in range(10)
        ]
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=files,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 200
        cleanup_doc_ids.extend(resp.json().get("document_ids", []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Test B: Per-file size limit
# ---------------------------------------------------------------------------


class TestPerFileSizeLimit:
    """A single file exceeding 50 MB should be rejected."""

    async def test_oversized_file_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        oversized = io.BytesIO(b"\x00" * (50 * 1024 * 1024 + 1))
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=[("files", ("big.pdf", oversized, "application/pdf"))],
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 413
        assert "per-file limit" in resp.json()["detail"].lower()

    async def test_file_at_limit_accepted(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        at_limit = io.BytesIO(b"\x00" * (50 * 1024 * 1024))
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=[("files", ("exact50mb.txt", at_limit, "text/plain"))],
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code == 200
        cleanup_doc_ids.extend(resp.json().get("document_ids", []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test C: Total upload size limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTotalSizeLimit:
|
||||
"""Multiple files whose combined size exceeds 200 MB should be rejected."""
|
||||
|
||||
async def test_total_size_over_200mb_returns_413(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
chunk_size = 45 * 1024 * 1024 # 45 MB each
|
||||
files = [
|
||||
(
|
||||
"files",
|
||||
(f"chunk_{i}.txt", io.BytesIO(b"\x00" * chunk_size), "text/plain"),
|
||||
)
|
||||
for i in range(5) # 5 x 45 MB = 225 MB > 200 MB
|
||||
]
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=files,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
assert "total upload size" in resp.json()["detail"].lower()
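
The three limits exercised by this file (10 files, 50 MB per file, 200 MB in total, each answered with HTTP 413) can be expressed as one ordered validation pass. The following is a minimal sketch of such a check, not the route's actual code: `validate_upload`, `UploadTooLarge`, and the constant names are hypothetical, and the error wording is chosen only so that it contains the phrases the tests look for.

```python
from __future__ import annotations

MAX_FILES_PER_UPLOAD = 10
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024      # 50 MB per file
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024    # 200 MB per request


class UploadTooLarge(Exception):
    """Hypothetical error that a web layer would map to HTTP 413."""

    status_code = 413


def validate_upload(file_sizes: list[int]) -> None:
    """Raise UploadTooLarge if any limit is exceeded; sizes are in bytes."""
    if len(file_sizes) > MAX_FILES_PER_UPLOAD:
        raise UploadTooLarge(
            f"Too many files: {len(file_sizes)} > {MAX_FILES_PER_UPLOAD}"
        )
    for size in file_sizes:
        # A file of exactly 50 MB is allowed; only larger files are rejected.
        if size > MAX_FILE_SIZE_BYTES:
            raise UploadTooLarge("File exceeds the per-file limit of 50 MB")
    if sum(file_sizes) > MAX_TOTAL_SIZE_BYTES:
        raise UploadTooLarge("Total upload size exceeds the limit of 200 MB")
```

The ordering matters for the total-size case: five 45 MB files each pass the per-file check, so only the final sum comparison (225 MB against 200 MB) rejects the request.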
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
|
||||
"""Document status is READY after successful indexing."""
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(document.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
|
||||
"""Document content is set to the LLM-generated summary."""
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
assert document.content == "Mocked summary."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker):
|
||||
"""Chunks derived from the source markdown are persisted in the DB."""
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.search_space_id == db_search_space.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
chunks_result = await db_session.execute(
|
||||
select(Chunk).filter(Chunk.document_id == document.id)
|
||||
)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == "Test chunk content."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker):
|
||||
"""RuntimeError is raised when the indexing step fails so the caller can fire a failure notification."""
|
||||
with pytest.raises(RuntimeError):
|
||||
await index_uploaded_file(
|
||||
markdown_content="## Hello\n\nSome content.",
|
||||
filename="test.pdf",
|
||||
etl_service="UNSTRUCTURED",
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
session=db_session,
|
||||
llm=mocker.Mock(),
|
||||
)
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_sets_status_ready(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Document status is READY after successful indexing."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_content_is_summary_when_should_summarize_true(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Document content is set to the LLM-generated summary when should_summarize=True."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.content == "Mocked summary."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_content_is_source_markdown_when_should_summarize_false(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
):
|
||||
"""Document content is set to source_markdown verbatim when should_summarize=False."""
|
||||
connector_doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
should_summarize=False,
|
||||
source_markdown="## Raw content",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=None)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.content == "## Raw content"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_chunks_written_to_db(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Chunks derived from source_markdown are persisted in the DB."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Chunk).filter(Chunk.document_id == document_id)
|
||||
)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == "Test chunk content."
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_embedding_written_to_db(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""Document embedding vector is persisted in the DB after indexing."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.embedding is not None
|
||||
assert len(reloaded.embedding) == _EMBEDDING_DIM
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_updated_at_advances_after_indexing(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
mocker,
|
||||
):
|
||||
"""updated_at timestamp is later after indexing than it was at prepare time."""
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_pending = result.scalars().first().updated_at
|
||||
|
||||
await service.index(document, connector_doc, llm=mocker.Mock())
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
updated_at_ready = result.scalars().first().updated_at
|
||||
|
||||
assert updated_at_ready > updated_at_pending
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_text", "patched_chunk_text"
|
||||
)
|
||||
async def test_no_llm_falls_back_to_source_markdown(
|
||||
db_session,
|
||||
db_search_space,
|
||||
make_connector_document,
|
||||
):
|
||||
"""When llm=None and no fallback_summary, content falls back to source_markdown."""
|
||||
connector_doc = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
should_summarize=True,
|
||||
source_markdown="## Fallback content",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
prepared = await service.prepare_for_indexing([connector_doc])
|
||||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, llm=None)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
|
||||
assert reloaded.content == "## Fallback content"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_fallback_summary_used_when_llm_unavailable(
    db_session,
    db_search_space,
    make_connector_document,
):
    """fallback_summary is used as content when llm=None and should_summarize=True."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=True,
        source_markdown="## Full raw content",
        fallback_summary="Short pre-built summary.",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document_id = prepared[0].id

    await service.index(prepared[0], connector_doc, llm=None)

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
    assert reloaded.content == "Short pre-built summary."


@pytest.mark.usefixtures(
    "patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_reindex_replaces_old_chunks(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Re-indexing a document replaces its old chunks rather than appending."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v1",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    updated_doc = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v2",
    )
    re_prepared = await service.prepare_for_indexing([updated_doc])
    await service.index(re_prepared[0], updated_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    chunks = result.scalars().all()

    assert len(chunks) == 1


@pytest.mark.usefixtures(
    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_llm_error_sets_status_failed(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """Document status is FAILED when the LLM raises during indexing."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED)


@pytest.mark.usefixtures(
    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_llm_error_leaves_no_partial_data(
    db_session,
    db_search_space,
    make_connector_document,
    mocker,
):
    """A failed indexing attempt leaves no partial embedding or chunks in the DB."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Document).filter(Document.id == document_id)
    )
    reloaded = result.scalars().first()

    assert reloaded.embedding is None
    assert reloaded.content == "Pending..."

    chunks_result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    assert chunks_result.scalars().all() == []
Some files were not shown because too many files have changed in this diff.