mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00

Merge pull request #837 from CREDO23/test-document-creation

[Refactor] Core document creation / indexing pipeline with unit and e2e tests

This commit is contained in: commit 2e99f1e853

38 changed files with 5535 additions and 3342 deletions

112  .cursor/skills/tdd/SKILL.md  Normal file
@@ -0,0 +1,112 @@
---
name: tdd
description: Strict Python TDD workflow using pytest (Red-Green-Refactor). Use when the user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development.
---

# Test-Driven Development

## Philosophy

**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't.

**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure.

**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior.

See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines.

## Anti-Pattern: Horizontal Slices

**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code."

This produces crap tests:

- Tests written in bulk test _imagined_ behavior, not _actual_ behavior
- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior
- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine
- You outrun your headlights, committing to test structure before understanding the implementation

**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it.

```
WRONG (horizontal):
  RED: test1, test2, test3, test4, test5
  GREEN: impl1, impl2, impl3, impl4, impl5

RIGHT (vertical):
  RED→GREEN: test1→impl1
  RED→GREEN: test2→impl2
  RED→GREEN: test3→impl3
  ...
```

## Workflow

### 1. Planning

Before writing any code:

- [ ] Confirm with user what interface changes are needed
- [ ] Confirm with user which behaviors to test (prioritize)
- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation)
- [ ] Design interfaces for [testability](interface-design.md)
- [ ] List the behaviors to test (not implementation steps)
- [ ] Get user approval on the plan

Ask: "What should the public interface look like? Which behaviors are most important to test?"

**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case.

### 2. Tracer Bullet

Write ONE test that confirms ONE thing about the system:

```
RED: Write test for first behavior → test fails
GREEN: Write minimal code to pass → test passes
```

This is your tracer bullet: it proves the path works end-to-end.

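As a concrete sketch of one cycle in pytest style (the `slugify` function and its behavior are illustrative assumptions, not part of this skill):

```python
# GREEN: minimal implementation, written only after the test below had failed.
def slugify(title: str) -> str:
    # Just enough to pass the current test; no speculative features.
    return title.lower().replace(" ", "-")


# RED came first: this test was written before slugify existed and failed.
def test_slugify_lowercases_and_hyphenates():
    assert slugify("Hello World") == "hello-world"


test_slugify_lowercases_and_hyphenates()
```

The test asserts on the return value through the public function only, so it survives any internal rewrite of `slugify`.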
### 3. Incremental Loop

For each remaining behavior:

```
RED: Write next test → fails
GREEN: Minimal code to pass → passes
```

Rules:

- One test at a time
- Only enough code to pass the current test
- Don't anticipate future tests
- Keep tests focused on observable behavior

### 4. Refactor

After all tests pass, look for [refactor candidates](refactoring.md):

- [ ] Extract duplication
- [ ] Deepen modules (move complexity behind simple interfaces)
- [ ] Apply SOLID principles where natural
- [ ] Consider what new code reveals about existing code
- [ ] Run tests after each refactor step

**Never refactor while RED.** Get to GREEN first.

## Checklist Per Cycle

```
[ ] Test describes behavior, not implementation
[ ] Test uses public interface only
[ ] Test would survive internal refactor
[ ] Code is minimal for this test
[ ] No speculative features added
```
33  .cursor/skills/tdd/deep-modules.md  Normal file
@@ -0,0 +1,33 @@

# Deep Modules

From "A Philosophy of Software Design":

**Deep module** = small interface + lots of implementation

```
┌─────────────────────┐
│   Small Interface   │ ← Few methods, simple params
├─────────────────────┤
│                     │
│                     │
│ Deep Implementation │ ← Complex logic hidden
│                     │
│                     │
└─────────────────────┘
```

**Shallow module** = large interface + little implementation (avoid)

```
┌─────────────────────────────────┐
│         Large Interface         │ ← Many methods, complex params
├─────────────────────────────────┤
│       Thin Implementation       │ ← Just passes through
└─────────────────────────────────┘
```

When designing interfaces, ask:

- Can I reduce the number of methods?
- Can I simplify the parameters?
- Can I hide more complexity inside?
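A Python sketch of the same idea (the class, its retry policy, and the fake request are illustrative assumptions, not from this skill): one small public method, with the messy parts hidden behind it.

```python
import time


class Fetcher:
    """Deep module: one public method; retries and backoff are hidden inside."""

    def fetch(self, url: str) -> str:
        for attempt in range(3):
            try:
                return self._request(url)
            except ConnectionError:
                time.sleep(2 ** attempt)  # hidden complexity: exponential backoff
        raise ConnectionError(f"gave up on {url}")

    def _request(self, url: str) -> str:
        # Stand-in for a real HTTP call, so the sketch runs offline.
        return f"response from {url}"


print(Fetcher().fetch("https://example.com"))
```

Callers see only `fetch(url)`; the retry count and backoff schedule can change without touching any caller or test.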
33  .cursor/skills/tdd/interface-design.md  Normal file
@@ -0,0 +1,33 @@

# Interface Design for Testability

Good interfaces make testing natural:

1. **Accept dependencies, don't create them**

```python
# Testable: the gateway can be a fake in tests
def process_order(order, payment_gateway):
    return payment_gateway.charge(order.total)

# Hard to test: the real gateway is created internally
def process_order(order):
    gateway = StripeGateway()
    return gateway.charge(order.total)
```

2. **Return results, don't produce side effects**

```python
# Testable: assert directly on the return value
def calculate_discount(cart) -> float:
    return discount

# Hard to test: must inspect mutated state afterwards
def apply_discount(cart) -> None:
    cart.total -= discount
```

3. **Small surface area**

* Fewer methods = fewer tests needed
* Fewer params = simpler test setup
69  .cursor/skills/tdd/mocking.md  Normal file
@@ -0,0 +1,69 @@

# When to Mock

Mock at **system boundaries** only:

* External APIs (payment, email, etc.)
* Databases (sometimes - prefer a test DB)
* Time/randomness
* File system (sometimes)

Don't mock:

* Your own classes/modules
* Internal collaborators
* Anything you control

## Designing for Mockability

At system boundaries, design interfaces that are easy to mock:

**1. Use dependency injection**

Pass external dependencies in rather than creating them internally:

```python
import os

# Easy to mock: tests pass in a fake payment_client
def process_payment(order, payment_client):
    return payment_client.charge(order.total)

# Hard to mock: the real client is constructed inside
def process_payment(order):
    client = StripeClient(os.getenv("STRIPE_KEY"))
    return client.charge(order.total)
```

**2. Prefer SDK-style interfaces over generic fetchers**

Create specific functions for each external operation instead of one generic function with conditional logic:

```python
import requests

# GOOD: Each function is independently mockable
class UserAPI:
    def get_user(self, user_id):
        return requests.get(f"/users/{user_id}")

    def get_orders(self, user_id):
        return requests.get(f"/users/{user_id}/orders")

    def create_order(self, data):
        return requests.post("/orders", json=data)

# BAD: Mocking requires conditional logic inside the mock
class GenericAPI:
    def fetch(self, endpoint, method="GET", data=None):
        return requests.request(method, endpoint, json=data)
```

The SDK approach means:

* Each mock returns one specific shape
* No conditional logic in test setup
* Easier to see which endpoints a test exercises
* Type safety per endpoint
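The payoff of boundary-only mocking, as a self-contained sketch assuming the `process_payment` shape above: the external client is replaced by a hand-rolled fake, with no patching of internals.

```python
class FakePaymentClient:
    """Hand-rolled fake standing in for the payment boundary."""

    def __init__(self):
        self.charged = []

    def charge(self, amount):
        self.charged.append(amount)
        return {"status": "ok", "amount": amount}


def process_payment(order, payment_client):
    return payment_client.charge(order["total"])


fake = FakePaymentClient()
result = process_payment({"total": 42.0}, fake)
assert result["status"] == "ok"
assert fake.charged == [42.0]  # the test observes the boundary, not internals
```

Because the fake implements the same one-method interface, the test needs no mocking framework and no knowledge of how `process_payment` works inside.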
10  .cursor/skills/tdd/refactoring.md  Normal file
@@ -0,0 +1,10 @@

# Refactor Candidates

After a TDD cycle, look for:

- **Duplication** → Extract function/class
- **Long methods** → Break into private helpers (keep tests on the public interface)
- **Shallow modules** → Combine or deepen
- **Feature envy** → Move logic to where the data lives
- **Primitive obsession** → Introduce value objects
- **Existing code** the new code reveals as problematic
60  .cursor/skills/tdd/tests.md  Normal file
@@ -0,0 +1,60 @@

# Good and Bad Tests

## Good Tests

**Integration-style**: Test through real interfaces, not mocks of internal parts.

```python
# GOOD: Tests observable behavior
def test_user_can_checkout_with_valid_cart():
    cart = create_cart()
    cart.add(product)
    result = checkout(cart, payment_method)
    assert result.status == "confirmed"
```

Characteristics:

* Tests behavior users/callers care about
* Uses public API only
* Survives internal refactors
* Describes WHAT, not HOW
* One logical assertion per test

## Bad Tests

**Implementation-detail tests**: Coupled to internal structure.

```python
# BAD: Tests implementation details
def test_checkout_calls_payment_service_process():
    mock_payment = MagicMock()
    checkout(cart, mock_payment)
    mock_payment.process.assert_called_with(cart.total)
```

Red flags:

* Mocking internal collaborators
* Testing private methods
* Asserting on call counts/order
* Test breaks when refactoring without behavior change
* Test name describes HOW, not WHAT
* Verifying through external means instead of the interface

```python
# BAD: Bypasses the interface to verify
def test_create_user_saves_to_database():
    create_user({"name": "Alice"})
    row = db.query("SELECT * FROM users WHERE name = ?", ["Alice"])
    assert row is not None

# GOOD: Verifies through the interface
def test_create_user_makes_user_retrievable():
    user = create_user({"name": "Alice"})
    retrieved = get_user(user.id)
    assert retrieved.name == "Alice"
```
0  surfsense_backend/app/indexing_pipeline/__init__.py  Normal file

@@ -0,0 +1,46 @@
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService


async def index_uploaded_file(
    markdown_content: str,
    filename: str,
    etl_service: str,
    search_space_id: int,
    user_id: str,
    session: AsyncSession,
    llm,
) -> None:
    connector_doc = ConnectorDocument(
        title=filename,
        source_markdown=markdown_content,
        unique_id=filename,
        document_type=DocumentType.FILE,
        search_space_id=search_space_id,
        created_by_id=user_id,
        connector_id=None,
        should_summarize=True,
        should_use_code_chunker=False,
        fallback_summary=markdown_content[:4000],
        metadata={
            "FILE_NAME": filename,
            "ETL_SERVICE": etl_service,
        },
    )

    service = IndexingPipelineService(session)
    documents = await service.prepare_for_indexing([connector_doc])

    if not documents:
        raise RuntimeError("prepare_for_indexing returned no documents")

    indexed = await service.index(documents[0], connector_doc, llm)

    if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
        raise RuntimeError(indexed.status.get("reason", "Indexing failed"))

    indexed.content_needs_reindexing = False
    await session.commit()

@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field, field_validator

from app.db import DocumentType


class ConnectorDocument(BaseModel):
    """Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""

    title: str
    source_markdown: str
    unique_id: str
    document_type: DocumentType
    search_space_id: int = Field(gt=0)
    should_summarize: bool = True
    should_use_code_chunker: bool = False
    fallback_summary: str | None = None
    metadata: dict = {}
    connector_id: int | None = None
    created_by_id: str

    @field_validator("title", "source_markdown", "unique_id", "created_by_id")
    @classmethod
    def not_empty(cls, v: str, info) -> str:
        if not v.strip():
            raise ValueError(f"{info.field_name} must not be empty or whitespace")
        return v

@@ -0,0 +1,7 @@
from app.config import config


def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
    """Chunk a text string using the configured chunker and return the chunk texts."""
    chunker = config.code_chunker_instance if use_code_chunker else config.chunker_instance
    return [c.text for c in chunker.chunk(text)]

@@ -0,0 +1,6 @@
from app.config import config


def embed_text(text: str) -> list[float]:
    """Embed a single text string using the configured embedding model."""
    return config.embedding_model_instance.embed(text)
15  surfsense_backend/app/indexing_pipeline/document_hashing.py  Normal file
@@ -0,0 +1,15 @@
import hashlib

from app.indexing_pipeline.connector_document import ConnectorDocument


def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
    """Return a stable SHA-256 hash identifying a document by its source identity."""
    combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()


def compute_content_hash(doc: ConnectorDocument) -> str:
    """Return a SHA-256 hash of the document's content scoped to its search space."""
    combined = f"{doc.search_space_id}:{doc.source_markdown}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()
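The hashing scheme can be exercised standalone; this sketch inlines the relevant `ConnectorDocument` fields as plain arguments (the function name is illustrative, not from the PR):

```python
import hashlib


def identity_hash(document_type: str, unique_id: str, search_space_id: int) -> str:
    # Same shape as compute_unique_identifier_hash: type + source id + space.
    combined = f"{document_type}:{unique_id}:{search_space_id}"
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()


# The same unique_id in different search spaces yields different identities,
# so one file can be indexed independently per space.
a = identity_hash("FILE", "report.pdf", 1)
b = identity_hash("FILE", "report.pdf", 2)
assert a != b
assert a == identity_hash("FILE", "report.pdf", 1)  # stable across runs
```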

@@ -0,0 +1,39 @@
from datetime import UTC, datetime

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value

from app.db import Document, DocumentStatus


async def rollback_and_persist_failure(
    session: AsyncSession, document: Document, message: str
) -> None:
    """Roll back the current transaction and best-effort persist a failed status.

    Called exclusively from except blocks — must never raise, or the new exception
    would chain with the original and mask it entirely.
    """
    try:
        await session.rollback()
    except Exception:
        return  # Session is completely dead; nothing further we can do.
    try:
        await session.refresh(document)
        document.updated_at = datetime.now(UTC)
        document.status = DocumentStatus.failed(message)
        await session.commit()
    except Exception:
        pass  # Best-effort; document will be retried on the next sync.


def attach_chunks_to_document(document: Document, chunks: list) -> None:
    """Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
    set_committed_value(document, "chunks", chunks)
    session = object_session(document)
    if session is not None:
        if document.id is not None:
            for chunk in chunks:
                chunk.document_id = document.id
        session.add_all(chunks)

@@ -0,0 +1,28 @@
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import optimize_content_for_context_window


async def summarize_document(source_markdown: str, llm, metadata: dict | None = None) -> str:
    """Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
    model_name = getattr(llm, "model", "gpt-3.5-turbo")
    optimized_content = optimize_content_for_context_window(
        source_markdown, metadata, model_name
    )

    summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
    content_with_metadata = (
        f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
        f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
    )
    summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
    summary_content = summary_result.content

    if metadata:
        metadata_parts = ["# DOCUMENT METADATA"]
        for key, value in metadata.items():
            if value:
                metadata_parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
        metadata_section = "\n".join(metadata_parts)
        return f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"

    return summary_content
121  surfsense_backend/app/indexing_pipeline/exceptions.py  Normal file
@@ -0,0 +1,121 @@
from litellm.exceptions import (
    APIConnectionError,
    APIResponseValidationError,
    AuthenticationError,
    BadGatewayError,
    BadRequestError,
    InternalServerError,
    NotFoundError,
    PermissionDeniedError,
    RateLimitError,
    ServiceUnavailableError,
    Timeout,
    UnprocessableEntityError,
)
from sqlalchemy.exc import IntegrityError

# Tuples for use directly in except clauses.
RETRYABLE_LLM_ERRORS = (
    RateLimitError,
    Timeout,
    ServiceUnavailableError,
    BadGatewayError,
    InternalServerError,
    APIConnectionError,
)

PERMANENT_LLM_ERRORS = (
    AuthenticationError,
    PermissionDeniedError,
    NotFoundError,
    BadRequestError,
    UnprocessableEntityError,
    APIResponseValidationError,
)

# (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError.)
EMBEDDING_ERRORS = (
    RuntimeError,  # local device failure or API backend normalization
    OSError,  # model files missing or corrupted (local backends)
    MemoryError,  # document too large for available RAM
)


class PipelineMessages:
    RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
    LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
    LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
    LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
    LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
    LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."

    LLM_AUTH = "LLM authentication failed. Check your API key."
    LLM_PERMISSION = "LLM request denied. Check your account permissions."
    LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
    LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
    LLM_UNPROCESSABLE = "Document exceeds the LLM context window even after optimization."
    LLM_RESPONSE = "LLM returned an invalid response."

    EMBEDDING_FAILED = "Embedding failed. Check your embedding model configuration or service."
    EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
    EMBEDDING_MEMORY = "Not enough memory to embed this document."

    CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."


def safe_exception_message(exc: Exception) -> str:
    try:
        return str(exc)
    except Exception:
        return "Something went wrong during indexing. Error details could not be retrieved."


def llm_retryable_message(exc: Exception) -> str:
    try:
        if isinstance(exc, RateLimitError):
            return PipelineMessages.RATE_LIMIT
        if isinstance(exc, Timeout):
            return PipelineMessages.LLM_TIMEOUT
        if isinstance(exc, ServiceUnavailableError):
            return PipelineMessages.LLM_UNAVAILABLE
        if isinstance(exc, BadGatewayError):
            return PipelineMessages.LLM_BAD_GATEWAY
        if isinstance(exc, InternalServerError):
            return PipelineMessages.LLM_SERVER_ERROR
        if isinstance(exc, APIConnectionError):
            return PipelineMessages.LLM_CONNECTION
        return safe_exception_message(exc)
    except Exception:
        return "Something went wrong when calling the LLM."


def llm_permanent_message(exc: Exception) -> str:
    try:
        if isinstance(exc, AuthenticationError):
            return PipelineMessages.LLM_AUTH
        if isinstance(exc, PermissionDeniedError):
            return PipelineMessages.LLM_PERMISSION
        if isinstance(exc, NotFoundError):
            return PipelineMessages.LLM_NOT_FOUND
        if isinstance(exc, BadRequestError):
            return PipelineMessages.LLM_BAD_REQUEST
        if isinstance(exc, UnprocessableEntityError):
            return PipelineMessages.LLM_UNPROCESSABLE
        if isinstance(exc, APIResponseValidationError):
            return PipelineMessages.LLM_RESPONSE
        return safe_exception_message(exc)
    except Exception:
        return "Something went wrong when calling the LLM."


def embedding_message(exc: Exception) -> str:
    try:
        if isinstance(exc, RuntimeError):
            return PipelineMessages.EMBEDDING_FAILED
        if isinstance(exc, OSError):
            return PipelineMessages.EMBEDDING_MODEL
        if isinstance(exc, MemoryError):
            return PipelineMessages.EMBEDDING_MEMORY
        return safe_exception_message(exc)
    except Exception:
        return "Something went wrong when generating the embedding."
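The mapping pattern above is isinstance dispatch over an exception taxonomy; a stdlib-only sketch (the classes below are stand-ins, not the litellm types):

```python
class FakeRateLimitError(Exception):
    pass


class FakeAuthError(Exception):
    pass


RETRYABLE = (FakeRateLimitError, TimeoutError)
PERMANENT = (FakeAuthError,)


def message_for(exc: Exception) -> str:
    # Retryable first, then permanent, then fall back to the raw message.
    if isinstance(exc, RETRYABLE):
        return "Transient failure. Will retry on next sync."
    if isinstance(exc, PERMANENT):
        return "Permanent failure. Fix configuration."
    return str(exc)


assert message_for(FakeRateLimitError()) == "Transient failure. Will retry on next sync."
assert message_for(FakeAuthError()) == "Permanent failure. Fix configuration."
assert message_for(ValueError("bad input")) == "bad input"
```

Exposing the tuples (rather than a single classifier) lets callers use them directly in `except` clauses, as the pipeline service does.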

@@ -0,0 +1,237 @@
import contextlib
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_embedder import embed_text
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
attach_chunks_to_document,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.document_summarizer import summarize_document
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
EMBEDDING_ERRORS,
|
||||
PERMANENT_LLM_ERRORS,
|
||||
RETRYABLE_LLM_ERRORS,
|
||||
IntegrityError,
|
||||
PipelineMessages,
|
||||
embedding_message,
|
||||
llm_permanent_message,
|
||||
llm_retryable_message,
|
||||
safe_exception_message,
|
||||
)
|
||||
from app.indexing_pipeline.pipeline_logger import (
|
||||
PipelineLogContext,
|
||||
log_batch_aborted,
|
||||
log_chunking_overflow,
|
||||
log_doc_skipped_unknown,
|
||||
log_document_queued,
|
||||
log_document_requeued,
|
||||
log_document_updated,
|
||||
log_embedding_error,
|
||||
log_index_started,
|
||||
log_index_success,
|
||||
log_permanent_llm_error,
|
||||
log_race_condition,
|
||||
log_retryable_llm_error,
|
||||
log_unexpected_error,
|
||||
)
|
||||
|
||||
|
||||
class IndexingPipelineService:
|
||||
"""Single pipeline for indexing connector documents. All connectors use this service."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def prepare_for_indexing(
|
||||
self, connector_docs: list[ConnectorDocument]
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Persist new documents and detect changes, returning only those that need indexing.
|
||||
"""
|
||||
documents = []
|
||||
seen_hashes: set[str] = set()
|
||||
batch_ctx = PipelineLogContext(
|
||||
connector_id=connector_docs[0].connector_id if connector_docs else 0,
|
||||
search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
|
||||
unique_id="batch",
|
||||
)
|
||||
|
||||
for connector_doc in connector_docs:
|
||||
ctx = PipelineLogContext(
|
||||
connector_id=connector_doc.connector_id,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
unique_id=connector_doc.unique_id,
|
||||
)
|
||||
try:
|
||||
unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
|
||||
content_hash = compute_content_hash(connector_doc)
|
||||
|
||||
if unique_identifier_hash in seen_hashes:
|
||||
continue
|
||||
seen_hashes.add(unique_identifier_hash)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(Document).filter(
|
||||
Document.unique_identifier_hash == unique_identifier_hash
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing is not None:
|
||||
if existing.content_hash == content_hash:
|
||||
if existing.title != connector_doc.title:
|
||||
existing.title = connector_doc.title
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
if not DocumentStatus.is_state(
|
||||
existing.status, DocumentStatus.READY
|
||||
):
|
||||
existing.status = DocumentStatus.pending()
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
documents.append(existing)
|
||||
log_document_requeued(ctx)
|
||||
continue
|
||||
|
||||
existing.title = connector_doc.title
|
||||
existing.content_hash = content_hash
|
||||
existing.source_markdown = connector_doc.source_markdown
|
||||
existing.document_metadata = connector_doc.metadata
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
existing.status = DocumentStatus.pending()
|
||||
documents.append(existing)
|
||||
log_document_updated(ctx)
|
||||
continue
|
||||
|
||||
duplicate = await self.session.execute(
|
||||
select(Document).filter(Document.content_hash == content_hash)
|
||||
)
|
||||
if duplicate.scalars().first() is not None:
|
||||
continue
|
||||
|
||||
document = Document(
|
||||
title=connector_doc.title,
|
||||
document_type=connector_doc.document_type,
|
||||
content="Pending...",
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
source_markdown=connector_doc.source_markdown,
|
||||
document_metadata=connector_doc.metadata,
|
||||
search_space_id=connector_doc.search_space_id,
|
||||
connector_id=connector_doc.connector_id,
|
||||
created_by_id=connector_doc.created_by_id,
|
||||
updated_at=datetime.now(UTC),
|
||||
status=DocumentStatus.pending(),
|
||||
)
|
||||
self.session.add(document)
|
||||
documents.append(document)
|
||||
log_document_queued(ctx)
|
||||
|
||||
except Exception as e:
|
||||
log_doc_skipped_unknown(ctx, e)
|
||||
|
||||
try:
|
||||
await self.session.commit()
|
||||
return documents
|
||||
except IntegrityError:
|
||||
# A concurrent worker committed a document with the same content_hash
|
||||
# or unique_identifier_hash between our check and our INSERT.
|
||||
# The document already exists — roll back and let the next sync run handle it.
|
||||
log_race_condition(batch_ctx)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
except Exception as e:
|
||||
log_batch_aborted(batch_ctx, e)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
|
||||
async def index(
|
||||
self, document: Document, connector_doc: ConnectorDocument, llm
|
||||
) -> Document:
|
||||
"""
|
||||
Run summarization, embedding, and chunking for a document and persist the results.
|
||||
"""
|
||||
ctx = PipelineLogContext(
|
||||
    connector_id=connector_doc.connector_id,
    search_space_id=connector_doc.search_space_id,
    unique_id=connector_doc.unique_id,
    doc_id=document.id,
)
try:
    log_index_started(ctx)
    document.status = DocumentStatus.processing()
    await self.session.commit()

    if connector_doc.should_summarize and llm is not None:
        content = await summarize_document(
            connector_doc.source_markdown, llm, connector_doc.metadata
        )
    elif connector_doc.should_summarize and connector_doc.fallback_summary:
        content = connector_doc.fallback_summary
    else:
        content = connector_doc.source_markdown

    embedding = embed_text(content)

    await self.session.execute(
        delete(Chunk).where(Chunk.document_id == document.id)
    )

    chunks = [
        Chunk(content=text, embedding=embed_text(text))
        for text in chunk_text(
            connector_doc.source_markdown,
            use_code_chunker=connector_doc.should_use_code_chunker,
        )
    ]

    document.content = content
    document.embedding = embedding
    attach_chunks_to_document(document, chunks)
    document.updated_at = datetime.now(UTC)
    document.status = DocumentStatus.ready()
    await self.session.commit()
    log_index_success(ctx, chunk_count=len(chunks))

except RETRYABLE_LLM_ERRORS as e:
    log_retryable_llm_error(ctx, e)
    await rollback_and_persist_failure(
        self.session, document, llm_retryable_message(e)
    )

except PERMANENT_LLM_ERRORS as e:
    log_permanent_llm_error(ctx, e)
    await rollback_and_persist_failure(
        self.session, document, llm_permanent_message(e)
    )

except RecursionError as e:
    log_chunking_overflow(ctx, e)
    await rollback_and_persist_failure(
        self.session, document, PipelineMessages.CHUNKING_OVERFLOW
    )

except EMBEDDING_ERRORS as e:
    log_embedding_error(ctx, e)
    await rollback_and_persist_failure(
        self.session, document, embedding_message(e)
    )

except Exception as e:
    log_unexpected_error(ctx, e)
    await rollback_and_persist_failure(
        self.session, document, safe_exception_message(e)
    )

with contextlib.suppress(Exception):
    await self.session.refresh(document)

return document
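Every failure branch above funnels through `rollback_and_persist_failure`, whose body is not shown in this hunk. The sketch below is a guess at what such a helper plausibly does, under the assumption that it discards partial work before recording the failure; the `FakeSession`, `FakeDocument`, and tuple-shaped status are illustrative stand-ins, not the real SQLAlchemy models.

```python
import asyncio


class FakeSession:
    """Illustrative stand-in for the AsyncSession used above."""

    def __init__(self):
        self.calls = []

    async def rollback(self):
        self.calls.append("rollback")

    async def commit(self):
        self.calls.append("commit")


class FakeDocument:
    def __init__(self):
        self.status = ("PROCESSING", "")


async def rollback_and_persist_failure(session, document, message: str) -> None:
    # Hypothetical body: discard any partially-written chunks/embedding
    # from the failed attempt...
    await session.rollback()
    # ...then persist only the FAILED status with a user-facing reason.
    document.status = ("FAILED", message)
    await session.commit()


session = FakeSession()
doc = FakeDocument()
asyncio.run(rollback_and_persist_failure(session, doc, "LLM unavailable"))
# session.calls == ["rollback", "commit"]; doc.status == ("FAILED", "LLM unavailable")
```

The ordering matters: rolling back first guarantees the subsequent commit persists nothing from the failed indexing attempt except the status change itself.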
118 surfsense_backend/app/indexing_pipeline/pipeline_logger.py Normal file

@@ -0,0 +1,118 @@
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class PipelineLogContext:
    connector_id: int | None
    search_space_id: int
    unique_id: str  # always available from ConnectorDocument
    doc_id: int | None = None  # set once the DB row exists (index phase only)


class LogMessages:
    # prepare_for_indexing
    DOCUMENT_QUEUED = "New document queued for indexing."
    DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
    DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
    DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
    BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
    RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."

    # index
    INDEX_STARTED = "Document indexing started."
    INDEX_SUCCESS = "Document indexed successfully."
    LLM_RETRYABLE = "Retryable LLM error — document marked failed, will retry on next sync."
    LLM_PERMANENT = "Permanent LLM error — document marked failed."
    EMBEDDING_FAILED = "Embedding error — document marked failed."
    CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
    UNEXPECTED = "Unexpected error — document marked failed."


def _format_context(ctx: PipelineLogContext) -> str:
    parts = [
        f"connector_id={ctx.connector_id}",
        f"search_space_id={ctx.search_space_id}",
        f"unique_id={ctx.unique_id}",
    ]
    if ctx.doc_id is not None:
        parts.append(f"doc_id={ctx.doc_id}")
    return " ".join(parts)


def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
    try:
        parts = [msg, _format_context(ctx)]
        for key, val in extra.items():
            parts.append(f"{key}={val}")
        return " ".join(parts)
    except Exception:
        return msg


def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra) -> None:
    # Logging must never raise — a broken log call inside an except block would
    # chain with the original exception and mask it entirely.
    try:
        message = _build_message(msg, ctx, **extra)
        level_fn(message, exc_info=exc_info)
    except Exception:
        pass


# ── prepare_for_indexing ──────────────────────────────────────────────────────


def log_document_queued(ctx: PipelineLogContext) -> None:
    _safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)


def log_document_updated(ctx: PipelineLogContext) -> None:
    _safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)


def log_document_requeued(ctx: PipelineLogContext) -> None:
    _safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)


def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
    _safe_log(logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc)


def log_race_condition(ctx: PipelineLogContext) -> None:
    _safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)


def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
    _safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)


# ── index ─────────────────────────────────────────────────────────────────────


def log_index_started(ctx: PipelineLogContext) -> None:
    _safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)


def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
    _safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)


def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
    _safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)


def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
    _safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)


def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
    _safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)


def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
    _safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)


def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
    _safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)
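The `_safe_log` wrapper exists because a log call that raises inside an `except` block would chain with, and effectively mask, the original exception. A standalone sketch of the same guard, with illustrative handler functions standing in for the real `logger.info`/`logger.error`:

```python
def _safe_log(level_fn, msg):
    # Same guard as the module above: a broken logging call must never
    # propagate out of an except block and mask the original error.
    try:
        level_fn(msg)
    except Exception:
        pass


captured = []


def working_handler(msg):
    captured.append(msg)


def broken_handler(msg):
    raise RuntimeError("formatting blew up")


_safe_log(working_handler, "indexed ok")
_safe_log(broken_handler, "must not raise")
# captured == ["indexed ok"], and neither call let an exception escape
```

Note that `_build_message` in the module applies the same defense a second time, falling back to the bare message if context formatting fails, so even a pathological `__repr__` on an extra value cannot break a log line.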
@@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

 from app.config import config as app_config
 from app.db import Document, DocumentStatus, DocumentType, Log, Notification
+from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
 from app.services.llm_service import get_user_long_context_llm
 from app.services.notification_service import NotificationService
 from app.services.task_logging_service import TaskLoggingService

@@ -33,7 +34,6 @@ from .base import (
     check_document_by_unique_identifier,
     check_duplicate_document,
     get_current_timestamp,
-    safe_set_chunks,
 )
 from .markdown_processor import add_received_markdown_file_document

@@ -1863,7 +1863,7 @@ async def process_file_in_background_with_document(
         )
         return None

-    # ===== STEP 3: Generate embeddings and chunks =====
+    # ===== STEP 3+4: Index via pipeline =====
     if notification:
         await NotificationService.document_processing.notify_processing_progress(
             session, notification, stage="chunking"

@@ -1871,57 +1871,22 @@ async def process_file_in_background_with_document(

     user_llm = await get_user_long_context_llm(session, user_id, search_space_id)

-    if user_llm:
-        document_metadata = {
-            "file_name": filename,
-            "etl_service": etl_service,
-            "document_type": "File Document",
-        }
-        summary_content, summary_embedding = await generate_document_summary(
-            markdown_content, user_llm, document_metadata
-        )
-    else:
-        # Fallback: use truncated content as summary
-        summary_content = markdown_content[:4000]
-        from app.config import config
-
-        summary_embedding = config.embedding_model_instance.embed(summary_content)
-
-    chunks = await create_document_chunks(markdown_content)
-
-    # ===== STEP 4: Update document to READY =====
-    from sqlalchemy.orm.attributes import flag_modified
-
-    document.title = filename
-    document.content = summary_content
-    document.content_hash = content_hash
-    document.embedding = summary_embedding
-    document.document_metadata = {
-        "FILE_NAME": filename,
-        "ETL_SERVICE": etl_service or "UNKNOWN",
-        **(document.document_metadata or {}),
-    }
-    flag_modified(document, "document_metadata")
-
-    # Use safe_set_chunks to avoid async issues
-    safe_set_chunks(document, chunks)
-
-    document.source_markdown = markdown_content
-    document.content_needs_reindexing = False
-    document.updated_at = get_current_timestamp()
-    document.status = DocumentStatus.ready()  # Shows checkmark in UI
-
-    await session.commit()
-    await session.refresh(document)
+    await index_uploaded_file(
+        markdown_content=markdown_content,
+        filename=filename,
+        etl_service=etl_service,
+        search_space_id=search_space_id,
+        user_id=user_id,
+        session=session,
+        llm=user_llm,
+    )

     await task_logger.log_task_success(
         log_entry,
         f"Successfully processed file: (unknown)",
         {
             "document_id": document.id,
             "content_hash": content_hash,
             "file_type": etl_service,
             "chunks_count": len(chunks),
         },
     )
@@ -72,6 +72,22 @@ dependencies = [
 [dependency-groups]
 dev = [
     "ruff>=0.12.5",
+    "pytest>=8.0",
+    "pytest-asyncio>=0.25",
+    "pytest-mock>=3.14",
 ]

+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "session"
+asyncio_default_test_loop_scope = "session"
+testpaths = ["tests"]
+markers = [
+    "unit: pure logic tests, no DB or external services",
+    "integration: tests that require a real PostgreSQL database",
+]
+filterwarnings = [
+    "ignore::UserWarning:chonkie",
+]
+
 [tool.ruff]
0 surfsense_backend/tests/__init__.py Normal file
36 surfsense_backend/tests/conftest.py Normal file

@@ -0,0 +1,36 @@
import pytest

from app.db import DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument


@pytest.fixture
def sample_user_id() -> str:
    return "00000000-0000-0000-0000-000000000001"


@pytest.fixture
def sample_search_space_id() -> int:
    return 1


@pytest.fixture
def sample_connector_id() -> int:
    return 42


@pytest.fixture
def make_connector_document():
    def _make(**overrides):
        defaults = {
            "title": "Test Document",
            "source_markdown": "## Heading\n\nSome content.",
            "unique_id": "test-id-001",
            "document_type": DocumentType.CLICKUP_CONNECTOR,
            "search_space_id": 1,
            "connector_id": 1,
            "created_by_id": "00000000-0000-0000-0000-000000000001",
        }
        defaults.update(overrides)
        return ConnectorDocument(**defaults)

    return _make
0 surfsense_backend/tests/integration/__init__.py Normal file
164 surfsense_backend/tests/integration/conftest.py Normal file

@@ -0,0 +1,164 @@
import os
import uuid
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool

from app.db import (
    Base,
    DocumentType,
    SearchSpace,
    SearchSourceConnector,
    SearchSourceConnectorType,
    User,
)
from app.indexing_pipeline.connector_document import ConnectorDocument

_EMBEDDING_DIM = 1024  # must match the Vector() dimension used in DB column creation

_DEFAULT_TEST_DB = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)


@pytest_asyncio.fixture(scope="session")
async def async_engine():
    engine = create_async_engine(
        TEST_DATABASE_URL,
        poolclass=NullPool,
        echo=False,
        # Required for asyncpg + savepoints: disables the prepared statement cache
        # to prevent "another operation is in progress" errors during savepoint rollbacks.
        connect_args={"prepared_statement_cache_size": 0},
    )

    async with engine.begin() as conn:
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        await conn.run_sync(Base.metadata.create_all)

    yield engine

    # drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
    # DROP SCHEMA CASCADE handles this without needing a topological sort.
    async with engine.begin() as conn:
        await conn.execute(text("DROP SCHEMA public CASCADE"))
        await conn.execute(text("CREATE SCHEMA public"))

    await engine.dispose()


@pytest_asyncio.fixture
async def db_session(async_engine) -> AsyncSession:
    # Bind the session to a connection that holds an outer transaction.
    # join_transaction_mode="create_savepoint" makes session.commit() release
    # a SAVEPOINT instead of committing the outer transaction, so the final
    # transaction.rollback() undoes everything — including commits made by the
    # service under test — leaving the DB clean for the next test.
    async with async_engine.connect() as conn:
        transaction = await conn.begin()
        async with AsyncSession(
            bind=conn,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        ) as session:
            yield session
        await transaction.rollback()


@pytest_asyncio.fixture
async def db_user(db_session: AsyncSession) -> User:
    user = User(
        id=uuid.uuid4(),
        email="test@surfsense.net",
        hashed_password="hashed",
        is_active=True,
        is_superuser=False,
        is_verified=True,
    )
    db_session.add(user)
    await db_session.flush()
    return user


@pytest_asyncio.fixture
async def db_connector(
    db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace"
) -> SearchSourceConnector:
    connector = SearchSourceConnector(
        name="Test Connector",
        connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
        config={},
        search_space_id=db_search_space.id,
        user_id=db_user.id,
    )
    db_session.add(connector)
    await db_session.flush()
    return connector


@pytest_asyncio.fixture
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
    space = SearchSpace(
        name="Test Space",
        user_id=db_user.id,
    )
    db_session.add(space)
    await db_session.flush()
    return space


@pytest.fixture
def patched_summarize(monkeypatch) -> AsyncMock:
    mock = AsyncMock(return_value="Mocked summary.")
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
        mock,
    )
    return mock


@pytest.fixture
def patched_summarize_raises(monkeypatch) -> AsyncMock:
    mock = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
        mock,
    )
    return mock


@pytest.fixture
def patched_embed_text(monkeypatch) -> MagicMock:
    mock = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.embed_text",
        mock,
    )
    return mock


@pytest.fixture
def patched_chunk_text(monkeypatch) -> MagicMock:
    mock = MagicMock(return_value=["Test chunk content."])
    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.chunk_text",
        mock,
    )
    return mock


@pytest.fixture
def make_connector_document(db_connector, db_user):
    """Integration-scoped override: uses real DB connector and user IDs."""

    def _make(**overrides):
        defaults = {
            "title": "Test Document",
            "source_markdown": "## Heading\n\nSome content.",
            "unique_id": "test-id-001",
            "document_type": DocumentType.CLICKUP_CONNECTOR,
            "search_space_id": db_connector.search_space_id,
            "connector_id": db_connector.id,
            "created_by_id": str(db_user.id),
        }
        defaults.update(overrides)
        return ConnectorDocument(**defaults)

    return _make
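The `db_session` fixture's isolation trick rests on SAVEPOINT semantics: an inner "commit" only releases a savepoint nested inside the outer transaction, so rolling back the outer transaction still undoes it. The same mechanics, sketched with stdlib `sqlite3` in place of asyncpg/SQLAlchemy (table and savepoint names are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:", isolation_level=None)  # manual transaction control
conn.execute("CREATE TABLE docs (id INTEGER)")

conn.execute("BEGIN")                        # outer transaction (fixture setup)
conn.execute("SAVEPOINT test_sp")            # what create_savepoint mode issues
conn.execute("INSERT INTO docs VALUES (1)")  # work done by the service under test
conn.execute("RELEASE SAVEPOINT test_sp")    # the session.commit() inside a test
conn.execute("ROLLBACK")                     # fixture teardown: outer rollback

count = conn.execute("SELECT COUNT(*) FROM docs").fetchone()[0]
# count == 0 — the "committed" insert was still inside the outer transaction
```

This is why the fixtures can exercise a service that calls `session.commit()` itself, yet every test still starts from an empty database.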
@@ -0,0 +1,91 @@
import pytest
from sqlalchemy import select

from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file

pytestmark = pytest.mark.integration


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
    """Document status is READY after successful indexing."""
    await index_uploaded_file(
        markdown_content="## Hello\n\nSome content.",
        filename="test.pdf",
        etl_service="UNSTRUCTURED",
        search_space_id=db_search_space.id,
        user_id=str(db_user.id),
        session=db_session,
        llm=mocker.Mock(),
    )

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    document = result.scalars().first()

    assert DocumentStatus.is_state(document.status, DocumentStatus.READY)


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
    """Document content is set to the LLM-generated summary."""
    await index_uploaded_file(
        markdown_content="## Hello\n\nSome content.",
        filename="test.pdf",
        etl_service="UNSTRUCTURED",
        search_space_id=db_search_space.id,
        user_id=str(db_user.id),
        session=db_session,
        llm=mocker.Mock(),
    )

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    document = result.scalars().first()

    assert document.content == "Mocked summary."


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker):
    """Chunks derived from the source markdown are persisted in the DB."""
    await index_uploaded_file(
        markdown_content="## Hello\n\nSome content.",
        filename="test.pdf",
        etl_service="UNSTRUCTURED",
        search_space_id=db_search_space.id,
        user_id=str(db_user.id),
        session=db_session,
        llm=mocker.Mock(),
    )

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    document = result.scalars().first()

    chunks_result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document.id)
    )
    chunks = chunks_result.scalars().all()

    assert len(chunks) == 1
    assert chunks[0].content == "Test chunk content."


@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker):
    """RuntimeError is raised when the indexing step fails so the caller can fire a failure notification."""
    with pytest.raises(RuntimeError):
        await index_uploaded_file(
            markdown_content="## Hello\n\nSome content.",
            filename="test.pdf",
            etl_service="UNSTRUCTURED",
            search_space_id=db_search_space.id,
            user_id=str(db_user.id),
            session=db_session,
            llm=mocker.Mock(),
        )
@@ -0,0 +1,266 @@
import pytest
from sqlalchemy import select

from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService

pytestmark = pytest.mark.integration


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_sets_status_ready(
    db_session, db_search_space, make_connector_document, mocker,
):
    """Document status is READY after successful indexing."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_content_is_summary_when_should_summarize_true(
    db_session, db_search_space, make_connector_document, mocker,
):
    """Document content is set to the LLM-generated summary when should_summarize=True."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert reloaded.content == "Mocked summary."


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_content_is_source_markdown_when_should_summarize_false(
    db_session, db_search_space, make_connector_document,
):
    """Document content is set to source_markdown verbatim when should_summarize=False."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=False,
        source_markdown="## Raw content",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=None)

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert reloaded.content == "## Raw content"


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_chunks_written_to_db(
    db_session, db_search_space, make_connector_document, mocker,
):
    """Chunks derived from source_markdown are persisted in the DB."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    chunks = result.scalars().all()

    assert len(chunks) == 1
    assert chunks[0].content == "Test chunk content."


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_embedding_written_to_db(
    db_session, db_search_space, make_connector_document, mocker,
):
    """Document embedding vector is persisted in the DB after indexing."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert reloaded.embedding is not None
    assert len(reloaded.embedding) == 1024


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_updated_at_advances_after_indexing(
    db_session, db_search_space, make_connector_document, mocker,
):
    """updated_at timestamp is later after indexing than it was at prepare time."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    updated_at_pending = result.scalars().first().updated_at

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    updated_at_ready = result.scalars().first().updated_at

    assert updated_at_ready > updated_at_pending


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_no_llm_falls_back_to_source_markdown(
    db_session, db_search_space, make_connector_document,
):
    """When llm=None and no fallback_summary, content falls back to source_markdown."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=True,
        source_markdown="## Fallback content",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=None)

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
    assert reloaded.content == "## Fallback content"


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_fallback_summary_used_when_llm_unavailable(
    db_session, db_search_space, make_connector_document,
):
    """fallback_summary is used as content when llm=None and should_summarize=True."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=True,
        source_markdown="## Full raw content",
        fallback_summary="Short pre-built summary.",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document_id = prepared[0].id

    await service.index(prepared[0], connector_doc, llm=None)

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
    assert reloaded.content == "Short pre-built summary."


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_reindex_replaces_old_chunks(
    db_session, db_search_space, make_connector_document, mocker,
):
    """Re-indexing a document replaces its old chunks rather than appending."""
    connector_doc = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v1",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    updated_doc = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v2",
    )
    re_prepared = await service.prepare_for_indexing([updated_doc])
    await service.index(re_prepared[0], updated_doc, llm=mocker.Mock())

    result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    chunks = result.scalars().all()

    assert len(chunks) == 1


@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_llm_error_sets_status_failed(
    db_session, db_search_space, make_connector_document, mocker,
):
    """Document status is FAILED when the LLM raises during indexing."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED)


@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_llm_error_leaves_no_partial_data(
    db_session, db_search_space, make_connector_document, mocker,
):
    """A failed indexing attempt leaves no partial embedding or chunks in the DB."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([connector_doc])
    document = prepared[0]
    document_id = document.id

    await service.index(document, connector_doc, llm=mocker.Mock())

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert reloaded.embedding is None
    assert reloaded.content == "Pending..."

    chunks_result = await db_session.execute(
        select(Chunk).filter(Chunk.document_id == document_id)
    )
    assert chunks_result.scalars().all() == []
@ -0,0 +1,377 @@
|
|||
import pytest
from sqlalchemy import select

from app.db import Document, DocumentStatus
from app.indexing_pipeline.document_hashing import compute_content_hash as real_compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService

pytestmark = pytest.mark.integration


async def test_new_document_is_persisted_with_pending_status(
    db_session, db_search_space, make_connector_document
):
    """A new document is created in the DB with PENDING status and correct markdown."""
    doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    results = await service.prepare_for_indexing([doc])

    assert len(results) == 1
    document_id = results[0].id

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert reloaded is not None
    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
    assert reloaded.source_markdown == doc.source_markdown


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_unchanged_ready_document_is_skipped(
    db_session, db_search_space, make_connector_document, mocker,
):
    """A READY document with unchanged content is not returned for re-indexing."""
    doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    # Index fully so the document reaches ready state
    prepared = await service.prepare_for_indexing([doc])
    await service.index(prepared[0], doc, llm=mocker.Mock())

    # Same content on the next run — a ready document must be skipped
    results = await service.prepare_for_indexing([doc])

    assert results == []


@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_title_only_change_updates_title_in_db(
    db_session, db_search_space, make_connector_document, mocker,
):
    """A title-only change updates the DB title without re-queuing the document."""
    original = make_connector_document(search_space_id=db_search_space.id, title="Original Title")
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([original])
    document_id = prepared[0].id
    await service.index(prepared[0], original, llm=mocker.Mock())

    renamed = make_connector_document(search_space_id=db_search_space.id, title="Updated Title")
    results = await service.prepare_for_indexing([renamed])

    assert results == []

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert reloaded.title == "Updated Title"


async def test_changed_content_is_returned_for_reprocessing(
    db_session, db_search_space, make_connector_document
):
    """A document with changed content is returned for re-indexing with updated markdown."""
    original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
    service = IndexingPipelineService(session=db_session)

    first = await service.prepare_for_indexing([original])
    original_id = first[0].id

    updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
    results = await service.prepare_for_indexing([updated])

    assert len(results) == 1
    assert results[0].id == original_id

    result = await db_session.execute(select(Document).filter(Document.id == original_id))
    reloaded = result.scalars().first()

    assert reloaded.source_markdown == "## v2"
    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)


async def test_all_documents_in_batch_are_persisted(
    db_session, db_search_space, make_connector_document
):
    """All documents in a batch are persisted and returned."""
    docs = [
        make_connector_document(search_space_id=db_search_space.id, unique_id="id-1", title="Doc 1", source_markdown="## Content 1"),
        make_connector_document(search_space_id=db_search_space.id, unique_id="id-2", title="Doc 2", source_markdown="## Content 2"),
        make_connector_document(search_space_id=db_search_space.id, unique_id="id-3", title="Doc 3", source_markdown="## Content 3"),
    ]
    service = IndexingPipelineService(session=db_session)

    results = await service.prepare_for_indexing(docs)

    assert len(results) == 3

    result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
    rows = result.scalars().all()

    assert len(rows) == 3


async def test_duplicate_in_batch_is_persisted_once(
    db_session, db_search_space, make_connector_document
):
    """The same document passed twice in a batch is only persisted once."""
    doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    results = await service.prepare_for_indexing([doc, doc])

    assert len(results) == 1

    result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
    rows = result.scalars().all()

    assert len(rows) == 1


async def test_created_by_id_is_persisted(
    db_session, db_user, db_search_space, make_connector_document
):
    """created_by_id from the connector document is persisted on the DB row."""
    doc = make_connector_document(
        search_space_id=db_search_space.id,
        created_by_id=str(db_user.id),
    )
    service = IndexingPipelineService(session=db_session)

    results = await service.prepare_for_indexing([doc])
    document_id = results[0].id

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert str(reloaded.created_by_id) == str(db_user.id)


async def test_metadata_is_updated_when_content_changes(
    db_session, db_search_space, make_connector_document
):
    """document_metadata is overwritten with the latest metadata when content changes."""
    original = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v1",
        metadata={"status": "in_progress"},
    )
    service = IndexingPipelineService(session=db_session)

    first = await service.prepare_for_indexing([original])
    document_id = first[0].id

    updated = make_connector_document(
        search_space_id=db_search_space.id,
        source_markdown="## v2",
        metadata={"status": "done"},
    )
    await service.prepare_for_indexing([updated])

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    reloaded = result.scalars().first()

    assert reloaded.document_metadata == {"status": "done"}


async def test_updated_at_advances_when_title_only_changes(
    db_session, db_search_space, make_connector_document
):
    """updated_at advances even when only the title changes."""
    original = make_connector_document(search_space_id=db_search_space.id, title="Old Title")
    service = IndexingPipelineService(session=db_session)

    first = await service.prepare_for_indexing([original])
    document_id = first[0].id

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    updated_at_v1 = result.scalars().first().updated_at

    renamed = make_connector_document(search_space_id=db_search_space.id, title="New Title")
    await service.prepare_for_indexing([renamed])

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    updated_at_v2 = result.scalars().first().updated_at

    assert updated_at_v2 > updated_at_v1


async def test_updated_at_advances_when_content_changes(
    db_session, db_search_space, make_connector_document
):
    """updated_at advances when document content changes."""
    original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
    service = IndexingPipelineService(session=db_session)

    first = await service.prepare_for_indexing([original])
    document_id = first[0].id

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    updated_at_v1 = result.scalars().first().updated_at

    updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
    await service.prepare_for_indexing([updated])

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    updated_at_v2 = result.scalars().first().updated_at

    assert updated_at_v2 > updated_at_v1


async def test_same_content_from_different_source_skipped_in_single_batch(
    db_session, db_search_space, make_connector_document
):
    """Two documents with identical content in the same batch result in only one being persisted."""
    first = make_connector_document(
        search_space_id=db_search_space.id,
        unique_id="source-a",
        source_markdown="## Shared content",
    )
    second = make_connector_document(
        search_space_id=db_search_space.id,
        unique_id="source-b",
        source_markdown="## Shared content",
    )
    service = IndexingPipelineService(session=db_session)

    results = await service.prepare_for_indexing([first, second])

    assert len(results) == 1

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    assert len(result.scalars().all()) == 1


async def test_same_content_from_different_source_is_skipped(
    db_session, db_search_space, make_connector_document
):
    """A document with content identical to an already-indexed document is skipped."""
    first = make_connector_document(
        search_space_id=db_search_space.id,
        unique_id="source-a",
        source_markdown="## Shared content",
    )
    second = make_connector_document(
        search_space_id=db_search_space.id,
        unique_id="source-b",
        source_markdown="## Shared content",
    )
    service = IndexingPipelineService(session=db_session)

    await service.prepare_for_indexing([first])
    results = await service.prepare_for_indexing([second])

    assert results == []

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    assert len(result.scalars().all()) == 1


@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_failed_document_with_unchanged_content_is_requeued(
    db_session, db_search_space, make_connector_document, mocker,
):
    """A FAILED document with unchanged content is re-queued as PENDING on the next run."""
    doc = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    # First run: document is created and indexing crashes → status = failed
    prepared = await service.prepare_for_indexing([doc])
    document_id = prepared[0].id
    await service.index(prepared[0], doc, llm=mocker.Mock())

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.FAILED)

    # Next run: same content, pipeline must re-queue the failed document
    results = await service.prepare_for_indexing([doc])

    assert len(results) == 1
    assert results[0].id == document_id

    result = await db_session.execute(select(Document).filter(Document.id == document_id))
    assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.PENDING)


async def test_title_and_content_change_updates_both_and_returns_document(
    db_session, db_search_space, make_connector_document
):
    """When both title and content change, both are updated and the document is returned for re-indexing."""
    original = make_connector_document(
        search_space_id=db_search_space.id,
        title="Original Title",
        source_markdown="## v1",
    )
    service = IndexingPipelineService(session=db_session)

    first = await service.prepare_for_indexing([original])
    original_id = first[0].id

    updated = make_connector_document(
        search_space_id=db_search_space.id,
        title="Updated Title",
        source_markdown="## v2",
    )
    results = await service.prepare_for_indexing([updated])

    assert len(results) == 1
    assert results[0].id == original_id

    result = await db_session.execute(select(Document).filter(Document.id == original_id))
    reloaded = result.scalars().first()

    assert reloaded.title == "Updated Title"
    assert reloaded.source_markdown == "## v2"


async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_persisted(
    db_session, db_search_space, make_connector_document, monkeypatch,
):
    """
    A per-document error during prepare_for_indexing must be isolated.
    The two valid documents around the failing one must still be persisted.
    """
    docs = [
        make_connector_document(
            search_space_id=db_search_space.id,
            unique_id="good-1",
            source_markdown="## Good doc 1",
        ),
        make_connector_document(
            search_space_id=db_search_space.id,
            unique_id="will-fail",
            source_markdown="## Bad doc",
        ),
        make_connector_document(
            search_space_id=db_search_space.id,
            unique_id="good-2",
            source_markdown="## Good doc 2",
        ),
    ]

    def compute_content_hash_with_error(doc):
        if doc.unique_id == "will-fail":
            raise RuntimeError("Simulated per-document failure")
        return real_compute_content_hash(doc)

    monkeypatch.setattr(
        "app.indexing_pipeline.indexing_pipeline_service.compute_content_hash",
        compute_content_hash_with_error,
    )

    service = IndexingPipelineService(session=db_session)
    results = await service.prepare_for_indexing(docs)

    assert len(results) == 2

    result = await db_session.execute(
        select(Document).filter(Document.search_space_id == db_search_space.id)
    )
    assert len(result.scalars().all()) == 2
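The skip, title-update, and re-queue behavior these integration tests assert all reduces to comparing a stored content hash and status against a freshly computed hash. A minimal sketch of that decision rule, assuming a sha256 content hash scoped to the search space — the names and the dataclass here are illustrative, not the repository's implementation:

```python
import hashlib
from dataclasses import dataclass


def content_hash(search_space_id: int, source_markdown: str) -> str:
    """Hash content together with its search space, so identical markdown
    in different spaces yields different hashes."""
    payload = f"{search_space_id}:{source_markdown}".encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


@dataclass
class StoredDoc:
    status: str  # "pending" | "ready" | "failed"
    content_hash: str


def needs_reindexing(stored: StoredDoc, search_space_id: int, markdown: str) -> bool:
    """Mirror the rules the tests above assert:
    - unchanged content + READY  -> skip
    - unchanged content + FAILED -> re-queue
    - changed content            -> re-queue
    """
    if content_hash(search_space_id, markdown) != stored.content_hash:
        return True
    return stored.status == "failed"
```

Title-only changes fall out naturally: the title is not part of the content hash, so the row is updated in place without being returned for re-indexing.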
0
surfsense_backend/tests/unit/__init__.py
Normal file
0
surfsense_backend/tests/unit/adapters/__init__.py
Normal file
33
surfsense_backend/tests/unit/indexing_pipeline/conftest.py
Normal file
@@ -0,0 +1,33 @@
import pytest
from unittest.mock import AsyncMock, MagicMock


@pytest.fixture
def patched_summarizer_chain(monkeypatch):
    chain = MagicMock()
    chain.ainvoke = AsyncMock(return_value=MagicMock(content="The summary."))

    template = MagicMock()
    template.__or__ = MagicMock(return_value=chain)

    monkeypatch.setattr(
        "app.indexing_pipeline.document_summarizer.SUMMARY_PROMPT_TEMPLATE",
        template,
    )
    return chain


@pytest.fixture
def patched_chunker_instance(monkeypatch):
    mock = MagicMock()
    mock.chunk.return_value = [MagicMock(text="prose chunk")]
    monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.chunker_instance", mock)
    return mock


@pytest.fixture
def patched_code_chunker_instance(monkeypatch):
    mock = MagicMock()
    mock.chunk.return_value = [MagicMock(text="code chunk")]
    monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock)
    return mock
@@ -0,0 +1,112 @@
import pytest
from pydantic import ValidationError

from app.db import DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument


def test_valid_document_created_with_required_fields():
    """All optional fields default correctly when only required fields are supplied."""
    doc = ConnectorDocument(
        title="Task",
        source_markdown="## Task\n\nSome content.",
        unique_id="task-1",
        document_type=DocumentType.CLICKUP_CONNECTOR,
        search_space_id=1,
        connector_id=42,
        created_by_id="00000000-0000-0000-0000-000000000001",
    )
    assert doc.should_summarize is True
    assert doc.should_use_code_chunker is False
    assert doc.metadata == {}
    assert doc.connector_id == 42
    assert doc.created_by_id == "00000000-0000-0000-0000-000000000001"


def test_omitting_created_by_id_raises():
    """Omitting created_by_id raises a validation error."""
    with pytest.raises(ValidationError):
        ConnectorDocument(
            title="Task",
            source_markdown="## Content",
            unique_id="task-1",
            document_type=DocumentType.CLICKUP_CONNECTOR,
            search_space_id=1,
            connector_id=42,
        )


def test_empty_source_markdown_raises():
    """Empty source_markdown raises a validation error."""
    with pytest.raises(ValidationError):
        ConnectorDocument(
            title="Task",
            source_markdown="",
            unique_id="task-1",
            document_type=DocumentType.CLICKUP_CONNECTOR,
            search_space_id=1,
        )


def test_whitespace_only_source_markdown_raises():
    """Whitespace-only source_markdown raises a validation error."""
    with pytest.raises(ValidationError):
        ConnectorDocument(
            title="Task",
            source_markdown=" \n\t ",
            unique_id="task-1",
            document_type=DocumentType.CLICKUP_CONNECTOR,
            search_space_id=1,
        )


def test_empty_title_raises():
    """Empty title raises a validation error."""
    with pytest.raises(ValidationError):
        ConnectorDocument(
            title="",
            source_markdown="## Content",
            unique_id="task-1",
            document_type=DocumentType.CLICKUP_CONNECTOR,
            search_space_id=1,
        )


def test_empty_created_by_id_raises():
    """Empty created_by_id raises a validation error."""
    with pytest.raises(ValidationError):
        ConnectorDocument(
            title="Task",
            source_markdown="## Content",
            unique_id="task-1",
            document_type=DocumentType.CLICKUP_CONNECTOR,
            search_space_id=1,
            connector_id=42,
            created_by_id="",
        )


def test_zero_search_space_id_raises():
    """search_space_id of zero raises a validation error."""
    with pytest.raises(ValidationError):
        ConnectorDocument(
            title="Task",
            source_markdown="## Content",
            unique_id="task-1",
            document_type=DocumentType.CLICKUP_CONNECTOR,
            search_space_id=0,
            connector_id=42,
            created_by_id="00000000-0000-0000-0000-000000000001",
        )


def test_empty_unique_id_raises():
    """Empty unique_id raises a validation error."""
    with pytest.raises(ValidationError):
        ConnectorDocument(
            title="Task",
            source_markdown="## Content",
            unique_id="",
            document_type=DocumentType.CLICKUP_CONNECTOR,
            search_space_id=1,
        )
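Read together, these validation tests specify the model's contract: non-empty title, unique_id, and created_by_id, non-blank markdown, and a positive search_space_id. The real model is a Pydantic class that raises ValidationError; the sketch below mirrors the same rules with a plain stdlib dataclass for illustration only:

```python
from dataclasses import dataclass, field


@dataclass
class ConnectorDocumentSketch:
    # Required fields; __post_init__ mirrors the validation rules the tests assert.
    title: str
    source_markdown: str
    unique_id: str
    document_type: str
    search_space_id: int
    connector_id: int
    created_by_id: str
    should_summarize: bool = True
    should_use_code_chunker: bool = False
    metadata: dict = field(default_factory=dict)

    def __post_init__(self):
        if not self.title:
            raise ValueError("title must not be empty")
        if not self.source_markdown.strip():
            raise ValueError("source_markdown must not be blank")
        if not self.unique_id:
            raise ValueError("unique_id must not be empty")
        if not self.created_by_id:
            raise ValueError("created_by_id must not be empty")
        if self.search_space_id <= 0:
            raise ValueError("search_space_id must be positive")
```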
@@ -0,0 +1,21 @@
import pytest

from app.indexing_pipeline.document_chunker import chunk_text

pytestmark = pytest.mark.unit


@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance")
def test_uses_code_chunker_when_flag_is_true():
    """Code chunker is selected when use_code_chunker=True."""
    result = chunk_text("def foo(): pass", use_code_chunker=True)

    assert result == ["code chunk"]


@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance")
def test_uses_default_chunker_when_flag_is_false():
    """Default prose chunker is selected when use_code_chunker=False."""
    result = chunk_text("Some prose text.", use_code_chunker=False)

    assert result == ["prose chunk"]
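The two tests above only constrain which configured chunker instance is consulted. Under the assumption that `chunk_text` is a thin dispatcher over two chunker instances, the selection logic can be sketched as a factory; the factory shape is illustrative, not the module's actual code:

```python
from typing import Callable, List


def make_chunk_text(
    prose_chunker: Callable[[str], List[str]],
    code_chunker: Callable[[str], List[str]],
) -> Callable[..., List[str]]:
    """Return a chunk_text function that dispatches on use_code_chunker,
    mirroring the selection behavior the two tests above assert."""
    def chunk_text(text: str, use_code_chunker: bool = False) -> List[str]:
        chunker = code_chunker if use_code_chunker else prose_chunker
        return chunker(text)
    return chunk_text
```

Injecting the chunkers as callables is also what makes the conftest fixtures above work: the tests only need to swap the configured instances to observe the dispatch.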
@@ -0,0 +1,48 @@
import pytest

from app.db import DocumentType
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash

pytestmark = pytest.mark.unit


def test_different_unique_id_produces_different_hash(make_connector_document):
    """Two documents with different unique_ids produce different identifier hashes."""
    doc_a = make_connector_document(unique_id="id-001")
    doc_b = make_connector_document(unique_id="id-002")
    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)


def test_different_search_space_produces_different_identifier_hash(make_connector_document):
    """Same document in different search spaces produces different identifier hashes."""
    doc_a = make_connector_document(search_space_id=1)
    doc_b = make_connector_document(search_space_id=2)
    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)


def test_different_document_type_produces_different_identifier_hash(make_connector_document):
    """Same unique_id with different document types produces different identifier hashes."""
    doc_a = make_connector_document(document_type=DocumentType.CLICKUP_CONNECTOR)
    doc_b = make_connector_document(document_type=DocumentType.NOTION_CONNECTOR)
    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)


def test_same_content_same_space_produces_same_content_hash(make_connector_document):
    """Identical content in the same search space always produces the same content hash."""
    doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
    doc_b = make_connector_document(source_markdown="Hello world", search_space_id=1)
    assert compute_content_hash(doc_a) == compute_content_hash(doc_b)


def test_same_content_different_space_produces_different_content_hash(make_connector_document):
    """Identical content in different search spaces produces different content hashes."""
    doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
    doc_b = make_connector_document(source_markdown="Hello world", search_space_id=2)
    assert compute_content_hash(doc_a) != compute_content_hash(doc_b)


def test_different_content_produces_different_content_hash(make_connector_document):
    """Different source markdown produces different content hashes."""
    doc_a = make_connector_document(source_markdown="Original content")
    doc_b = make_connector_document(source_markdown="Updated content")
    assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
@@ -0,0 +1,42 @@
import pytest
from unittest.mock import MagicMock

from app.indexing_pipeline.document_summarizer import summarize_document

pytestmark = pytest.mark.unit


@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_without_metadata_returns_raw_summary():
    """Summarizer returns the LLM output directly when no metadata is provided."""
    result = await summarize_document("# Content", llm=MagicMock(model="gpt-4"))

    assert result == "The summary."


@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_with_metadata_includes_metadata_values_in_output():
    """Non-empty metadata values are prepended to the summary output."""
    result = await summarize_document(
        "# Content",
        llm=MagicMock(model="gpt-4"),
        metadata={"author": "Alice", "source": "Notion"},
    )

    assert "Alice" in result
    assert "Notion" in result


@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_with_metadata_omits_empty_fields_from_output():
    """Empty metadata fields are omitted from the summary output."""
    result = await summarize_document(
        "# Content",
        llm=MagicMock(model="gpt-4"),
        metadata={"author": "Alice", "description": ""},
    )

    assert "Alice" in result
    assert "description" not in result.lower()
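These three tests constrain only the metadata handling around the LLM call: non-empty values appear in the output, empty fields are dropped entirely, and with no metadata the raw summary is returned unchanged. A sketch of just that post-processing step; the `key: value` line format is an assumption, and the real `summarize_document` also runs the prompt chain:

```python
from typing import Dict, Optional


def prepend_metadata(summary: str, metadata: Optional[Dict[str, str]] = None) -> str:
    """Prepend non-empty metadata values to an LLM summary, omitting blank
    fields, as the summarizer tests above require."""
    if not metadata:
        return summary
    # Keep only fields with truthy values, so empty strings vanish along
    # with their keys (the last test checks the key is absent too).
    lines = [f"{key}: {value}" for key, value in metadata.items() if value]
    if not lines:
        return summary
    return "\n".join(lines) + "\n\n" + summary
```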
6655
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large