Merge pull request #837 from CREDO23/test-document-creation

[Refactor] Core document creation / indexing pipeline with unit and e2e tests
This commit is contained in:
Rohan Verma 2026-02-25 12:27:41 -08:00 committed by GitHub
commit 2e99f1e853
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 5535 additions and 3342 deletions

112
.cursor/skills/tdd/SKILL.md Normal file
View file

@ -0,0 +1,112 @@
---
name: tdd
description: Strict Python TDD workflow using pytest (Red-Green-Refactor).
---
---
name: tdd
description: Test-driven development with red-green-refactor loop. Use when user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development.
---
# Test-Driven Development
## Philosophy
**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't.
**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure.
**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior.
See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines.
## Anti-Pattern: Horizontal Slices
**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code."
This produces **crap tests**:
- Tests written in bulk test _imagined_ behavior, not _actual_ behavior
- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior
- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine
- You outrun your headlights, committing to test structure before understanding the implementation
**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it.
```
WRONG (horizontal):
RED: test1, test2, test3, test4, test5
GREEN: impl1, impl2, impl3, impl4, impl5
RIGHT (vertical):
RED→GREEN: test1→impl1
RED→GREEN: test2→impl2
RED→GREEN: test3→impl3
...
```
## Workflow
### 1. Planning
Before writing any code:
- [ ] Confirm with user what interface changes are needed
- [ ] Confirm with user which behaviors to test (prioritize)
- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation)
- [ ] Design interfaces for [testability](interface-design.md)
- [ ] List the behaviors to test (not implementation steps)
- [ ] Get user approval on the plan
Ask: "What should the public interface look like? Which behaviors are most important to test?"
**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case.
### 2. Tracer Bullet
Write ONE test that confirms ONE thing about the system:
```
RED: Write test for first behavior → test fails
GREEN: Write minimal code to pass → test passes
```
This is your tracer bullet - proves the path works end-to-end.
### 3. Incremental Loop
For each remaining behavior:
```
RED: Write next test → fails
GREEN: Minimal code to pass → passes
```
Rules:
- One test at a time
- Only enough code to pass current test
- Don't anticipate future tests
- Keep tests focused on observable behavior
### 4. Refactor
After all tests pass, look for [refactor candidates](refactoring.md):
- [ ] Extract duplication
- [ ] Deepen modules (move complexity behind simple interfaces)
- [ ] Apply SOLID principles where natural
- [ ] Consider what new code reveals about existing code
- [ ] Run tests after each refactor step
**Never refactor while RED.** Get to GREEN first.
## Checklist Per Cycle
```
[ ] Test describes behavior, not implementation
[ ] Test uses public interface only
[ ] Test would survive internal refactor
[ ] Code is minimal for this test
[ ] No speculative features added
```

View file

@ -0,0 +1,33 @@
# Deep Modules
From "A Philosophy of Software Design":
**Deep module** = small interface + lots of implementation
```
┌─────────────────────┐
│ Small Interface │ ← Few methods, simple params
├─────────────────────┤
│ │
│ │
│ Deep Implementation│ ← Complex logic hidden
│ │
│ │
└─────────────────────┘
```
**Shallow module** = large interface + little implementation (avoid)
```
┌─────────────────────────────────┐
│ Large Interface │ ← Many methods, complex params
├─────────────────────────────────┤
│ Thin Implementation │ ← Just passes through
└─────────────────────────────────┘
```
When designing interfaces, ask:
- Can I reduce the number of methods?
- Can I simplify the parameters?
- Can I hide more complexity inside?

View file

@ -0,0 +1,33 @@
# Interface Design for Testability
Good interfaces make testing natural:
1. **Accept dependencies, don't create them**
```python
# Testable
def process_order(order, payment_gateway):
pass
# Hard to test
def process_order(order):
gateway = StripeGateway()
```
2. **Return results, don't produce side effects**
```python
# Testable
def calculate_discount(cart) -> float:
return discount
# Hard to test
def apply_discount(cart) -> None:
cart.total -= discount
```
3. **Small surface area**
* Fewer methods = fewer tests needed
* Fewer params = simpler test setup

View file

@ -0,0 +1,69 @@
# When to Mock
Mock at **system boundaries** only:
* External APIs (payment, email, etc.)
* Databases (sometimes - prefer test DB)
* Time/randomness
* File system (sometimes)
Don't mock:
* Your own classes/modules
* Internal collaborators
* Anything you control
## Designing for Mockability
At system boundaries, design interfaces that are easy to mock:
**1. Use dependency injection**
Pass external dependencies in rather than creating them internally:
```python
import os
# Easy to mock
def process_payment(order, payment_client):
return payment_client.charge(order.total)
# Hard to mock
def process_payment(order):
client = StripeClient(os.getenv("STRIPE_KEY"))
return client.charge(order.total)
```
**2. Prefer SDK-style interfaces over generic fetchers**
Create specific functions for each external operation instead of one generic function with conditional logic:
```python
import requests
# GOOD: Each function is independently mockable
class UserAPI:
def get_user(self, user_id):
return requests.get(f"/users/{user_id}")
def get_orders(self, user_id):
return requests.get(f"/users/{user_id}/orders")
def create_order(self, data):
return requests.post("/orders", json=data)
# BAD: Mocking requires conditional logic inside the mock
class GenericAPI:
def fetch(self, endpoint, method="GET", data=None):
return requests.request(method, endpoint, json=data)
```
The SDK approach means:
* Each mock returns one specific shape
* No conditional logic in test setup
* Easier to see which endpoints a test exercises
* Type safety per endpoint

View file

@ -0,0 +1,10 @@
# Refactor Candidates
After TDD cycle, look for:
- **Duplication** → Extract function/class
- **Long methods** → Break into private helpers (keep tests on public interface)
- **Shallow modules** → Combine or deepen
- **Feature envy** → Move logic to where data lives
- **Primitive obsession** → Introduce value objects
- **Existing code** the new code reveals as problematic

View file

@ -0,0 +1,60 @@
# Good and Bad Tests
## Good Tests
**Integration-style**: Test through real interfaces, not mocks of internal parts.
```python
# GOOD: Tests observable behavior
def test_user_can_checkout_with_valid_cart():
cart = create_cart()
cart.add(product)
result = checkout(cart, payment_method)
assert result.status == "confirmed"
```
Characteristics:
* Tests behavior users/callers care about
* Uses public API only
* Survives internal refactors
* Describes WHAT, not HOW
* One logical assertion per test
## Bad Tests
**Implementation-detail tests**: Coupled to internal structure.
```python
# BAD: Tests implementation details
def test_checkout_calls_payment_service_process():
mock_payment = MagicMock()
checkout(cart, mock_payment)
mock_payment.process.assert_called_with(cart.total)
```
Red flags:
* Mocking internal collaborators
* Testing private methods
* Asserting on call counts/order
* Test breaks when refactoring without behavior change
* Test name describes HOW not WHAT
* Verifying through external means instead of interface
```python
# BAD: Bypasses interface to verify
def test_create_user_saves_to_database():
create_user({"name": "Alice"})
row = db.query("SELECT * FROM users WHERE name = ?", ["Alice"])
assert row is not None
# GOOD: Verifies through interface
def test_create_user_makes_user_retrievable():
user = create_user({"name": "Alice"})
retrieved = get_user(user.id)
assert retrieved.name == "Alice"
```

View file

@ -0,0 +1,46 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
async def index_uploaded_file(
markdown_content: str,
filename: str,
etl_service: str,
search_space_id: int,
user_id: str,
session: AsyncSession,
llm,
) -> None:
connector_doc = ConnectorDocument(
title=filename,
source_markdown=markdown_content,
unique_id=filename,
document_type=DocumentType.FILE,
search_space_id=search_space_id,
created_by_id=user_id,
connector_id=None,
should_summarize=True,
should_use_code_chunker=False,
fallback_summary=markdown_content[:4000],
metadata={
"FILE_NAME": filename,
"ETL_SERVICE": etl_service,
},
)
service = IndexingPipelineService(session)
documents = await service.prepare_for_indexing([connector_doc])
if not documents:
raise RuntimeError("prepare_for_indexing returned no documents")
indexed = await service.index(documents[0], connector_doc, llm)
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
raise RuntimeError(indexed.status.get("reason", "Indexing failed"))
indexed.content_needs_reindexing = False
await session.commit()

View file

@ -0,0 +1,25 @@
from pydantic import BaseModel, Field, field_validator
from app.db import DocumentType
class ConnectorDocument(BaseModel):
"""Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""
title: str
source_markdown: str
unique_id: str
document_type: DocumentType
search_space_id: int = Field(gt=0)
should_summarize: bool = True
should_use_code_chunker: bool = False
fallback_summary: str | None = None
metadata: dict = {}
connector_id: int | None = None
created_by_id: str
@field_validator("title", "source_markdown", "unique_id", "created_by_id")
@classmethod
def not_empty(cls, v: str, info) -> str:
if not v.strip():
raise ValueError(f"{info.field_name} must not be empty or whitespace")
return v

View file

@ -0,0 +1,7 @@
from app.config import config
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
"""Chunk a text string using the configured chunker and return the chunk texts."""
chunker = config.code_chunker_instance if use_code_chunker else config.chunker_instance
return [c.text for c in chunker.chunk(text)]

View file

@ -0,0 +1,6 @@
from app.config import config
def embed_text(text: str) -> list[float]:
"""Embed a single text string using the configured embedding model."""
return config.embedding_model_instance.embed(text)

View file

@ -0,0 +1,15 @@
import hashlib
from app.indexing_pipeline.connector_document import ConnectorDocument
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
"""Return a stable SHA-256 hash identifying a document by its source identity."""
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
def compute_content_hash(doc: ConnectorDocument) -> str:
"""Return a SHA-256 hash of the document's content scoped to its search space."""
combined = f"{doc.search_space_id}:{doc.source_markdown}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()

View file

@ -0,0 +1,39 @@
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
from app.db import Document, DocumentStatus
async def rollback_and_persist_failure(
session: AsyncSession, document: Document, message: str
) -> None:
"""Roll back the current transaction and best-effort persist a failed status.
Called exclusively from except blocks must never raise, or the new exception
would chain with the original and mask it entirely.
"""
try:
await session.rollback()
except Exception:
return # Session is completely dead; nothing further we can do.
try:
await session.refresh(document)
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.failed(message)
await session.commit()
except Exception:
pass # Best-effort; document will be retried on the next sync.
def attach_chunks_to_document(document: Document, chunks: list) -> None:
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
set_committed_value(document, "chunks", chunks)
session = object_session(document)
if session is not None:
if document.id is not None:
for chunk in chunks:
chunk.document_id = document.id
session.add_all(chunks)

View file

@ -0,0 +1,28 @@
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import optimize_content_for_context_window
async def summarize_document(source_markdown: str, llm, metadata: dict | None = None) -> str:
"""Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
model_name = getattr(llm, "model", "gpt-3.5-turbo")
optimized_content = optimize_content_for_context_window(
source_markdown, metadata, model_name
)
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
content_with_metadata = (
f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
)
summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
summary_content = summary_result.content
if metadata:
metadata_parts = ["# DOCUMENT METADATA"]
for key, value in metadata.items():
if value:
metadata_parts.append(f"**{key.replace('_', ' ').title()}:** {value}")
metadata_section = "\n".join(metadata_parts)
return f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
return summary_content

View file

@ -0,0 +1,121 @@
from litellm.exceptions import (
APIConnectionError,
APIResponseValidationError,
AuthenticationError,
BadGatewayError,
BadRequestError,
InternalServerError,
NotFoundError,
PermissionDeniedError,
RateLimitError,
ServiceUnavailableError,
Timeout,
UnprocessableEntityError,
)
from sqlalchemy.exc import IntegrityError
# Tuples for use directly in except clauses.
RETRYABLE_LLM_ERRORS = (
RateLimitError,
Timeout,
ServiceUnavailableError,
BadGatewayError,
InternalServerError,
APIConnectionError,
)
PERMANENT_LLM_ERRORS = (
AuthenticationError,
PermissionDeniedError,
NotFoundError,
BadRequestError,
UnprocessableEntityError,
APIResponseValidationError,
)
# (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError).
EMBEDDING_ERRORS = (
RuntimeError, # local device failure or API backend normalization
OSError, # model files missing or corrupted (local backends)
MemoryError, # document too large for available RAM
)
class PipelineMessages:
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
LLM_AUTH = "LLM authentication failed. Check your API key."
LLM_PERMISSION = "LLM request denied. Check your account permissions."
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
LLM_UNPROCESSABLE = "Document exceeds the LLM context window even after optimization."
LLM_RESPONSE = "LLM returned an invalid response."
EMBEDDING_FAILED = "Embedding failed. Check your embedding model configuration or service."
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
EMBEDDING_MEMORY = "Not enough memory to embed this document."
CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."
def safe_exception_message(exc: Exception) -> str:
try:
return str(exc)
except Exception:
return "Something went wrong during indexing. Error details could not be retrieved."
def llm_retryable_message(exc: Exception) -> str:
try:
if isinstance(exc, RateLimitError):
return PipelineMessages.RATE_LIMIT
if isinstance(exc, Timeout):
return PipelineMessages.LLM_TIMEOUT
if isinstance(exc, ServiceUnavailableError):
return PipelineMessages.LLM_UNAVAILABLE
if isinstance(exc, BadGatewayError):
return PipelineMessages.LLM_BAD_GATEWAY
if isinstance(exc, InternalServerError):
return PipelineMessages.LLM_SERVER_ERROR
if isinstance(exc, APIConnectionError):
return PipelineMessages.LLM_CONNECTION
return safe_exception_message(exc)
except Exception:
return "Something went wrong when calling the LLM."
def llm_permanent_message(exc: Exception) -> str:
try:
if isinstance(exc, AuthenticationError):
return PipelineMessages.LLM_AUTH
if isinstance(exc, PermissionDeniedError):
return PipelineMessages.LLM_PERMISSION
if isinstance(exc, NotFoundError):
return PipelineMessages.LLM_NOT_FOUND
if isinstance(exc, BadRequestError):
return PipelineMessages.LLM_BAD_REQUEST
if isinstance(exc, UnprocessableEntityError):
return PipelineMessages.LLM_UNPROCESSABLE
if isinstance(exc, APIResponseValidationError):
return PipelineMessages.LLM_RESPONSE
return safe_exception_message(exc)
except Exception:
return "Something went wrong when calling the LLM."
def embedding_message(exc: Exception) -> str:
try:
if isinstance(exc, RuntimeError):
return PipelineMessages.EMBEDDING_FAILED
if isinstance(exc, OSError):
return PipelineMessages.EMBEDDING_MODEL
if isinstance(exc, MemoryError):
return PipelineMessages.EMBEDDING_MEMORY
return safe_exception_message(exc)
except Exception:
return "Something went wrong when generating the embedding."

View file

@ -0,0 +1,237 @@
import contextlib
from datetime import UTC, datetime
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_text
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
attach_chunks_to_document,
rollback_and_persist_failure,
)
from app.indexing_pipeline.document_summarizer import summarize_document
from app.indexing_pipeline.exceptions import (
EMBEDDING_ERRORS,
PERMANENT_LLM_ERRORS,
RETRYABLE_LLM_ERRORS,
IntegrityError,
PipelineMessages,
embedding_message,
llm_permanent_message,
llm_retryable_message,
safe_exception_message,
)
from app.indexing_pipeline.pipeline_logger import (
PipelineLogContext,
log_batch_aborted,
log_chunking_overflow,
log_doc_skipped_unknown,
log_document_queued,
log_document_requeued,
log_document_updated,
log_embedding_error,
log_index_started,
log_index_success,
log_permanent_llm_error,
log_race_condition,
log_retryable_llm_error,
log_unexpected_error,
)
class IndexingPipelineService:
"""Single pipeline for indexing connector documents. All connectors use this service."""
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def prepare_for_indexing(
self, connector_docs: list[ConnectorDocument]
) -> list[Document]:
"""
Persist new documents and detect changes, returning only those that need indexing.
"""
documents = []
seen_hashes: set[str] = set()
batch_ctx = PipelineLogContext(
connector_id=connector_docs[0].connector_id if connector_docs else 0,
search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
unique_id="batch",
)
for connector_doc in connector_docs:
ctx = PipelineLogContext(
connector_id=connector_doc.connector_id,
search_space_id=connector_doc.search_space_id,
unique_id=connector_doc.unique_id,
)
try:
unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
content_hash = compute_content_hash(connector_doc)
if unique_identifier_hash in seen_hashes:
continue
seen_hashes.add(unique_identifier_hash)
result = await self.session.execute(
select(Document).filter(
Document.unique_identifier_hash == unique_identifier_hash
)
)
existing = result.scalars().first()
if existing is not None:
if existing.content_hash == content_hash:
if existing.title != connector_doc.title:
existing.title = connector_doc.title
existing.updated_at = datetime.now(UTC)
if not DocumentStatus.is_state(
existing.status, DocumentStatus.READY
):
existing.status = DocumentStatus.pending()
existing.updated_at = datetime.now(UTC)
documents.append(existing)
log_document_requeued(ctx)
continue
existing.title = connector_doc.title
existing.content_hash = content_hash
existing.source_markdown = connector_doc.source_markdown
existing.document_metadata = connector_doc.metadata
existing.updated_at = datetime.now(UTC)
existing.status = DocumentStatus.pending()
documents.append(existing)
log_document_updated(ctx)
continue
duplicate = await self.session.execute(
select(Document).filter(Document.content_hash == content_hash)
)
if duplicate.scalars().first() is not None:
continue
document = Document(
title=connector_doc.title,
document_type=connector_doc.document_type,
content="Pending...",
content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash,
source_markdown=connector_doc.source_markdown,
document_metadata=connector_doc.metadata,
search_space_id=connector_doc.search_space_id,
connector_id=connector_doc.connector_id,
created_by_id=connector_doc.created_by_id,
updated_at=datetime.now(UTC),
status=DocumentStatus.pending(),
)
self.session.add(document)
documents.append(document)
log_document_queued(ctx)
except Exception as e:
log_doc_skipped_unknown(ctx, e)
try:
await self.session.commit()
return documents
except IntegrityError:
# A concurrent worker committed a document with the same content_hash
# or unique_identifier_hash between our check and our INSERT.
# The document already exists — roll back and let the next sync run handle it.
log_race_condition(batch_ctx)
await self.session.rollback()
return []
except Exception as e:
log_batch_aborted(batch_ctx, e)
await self.session.rollback()
return []
async def index(
self, document: Document, connector_doc: ConnectorDocument, llm
) -> Document:
"""
Run summarization, embedding, and chunking for a document and persist the results.
"""
ctx = PipelineLogContext(
connector_id=connector_doc.connector_id,
search_space_id=connector_doc.search_space_id,
unique_id=connector_doc.unique_id,
doc_id=document.id,
)
try:
log_index_started(ctx)
document.status = DocumentStatus.processing()
await self.session.commit()
if connector_doc.should_summarize and llm is not None:
content = await summarize_document(
connector_doc.source_markdown, llm, connector_doc.metadata
)
elif connector_doc.should_summarize and connector_doc.fallback_summary:
content = connector_doc.fallback_summary
else:
content = connector_doc.source_markdown
embedding = embed_text(content)
await self.session.execute(
delete(Chunk).where(Chunk.document_id == document.id)
)
chunks = [
Chunk(content=text, embedding=embed_text(text))
for text in chunk_text(
connector_doc.source_markdown,
use_code_chunker=connector_doc.should_use_code_chunker,
)
]
document.content = content
document.embedding = embedding
attach_chunks_to_document(document, chunks)
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.ready()
await self.session.commit()
log_index_success(ctx, chunk_count=len(chunks))
except RETRYABLE_LLM_ERRORS as e:
log_retryable_llm_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, llm_retryable_message(e)
)
except PERMANENT_LLM_ERRORS as e:
log_permanent_llm_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, llm_permanent_message(e)
)
except RecursionError as e:
log_chunking_overflow(ctx, e)
await rollback_and_persist_failure(
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
)
except EMBEDDING_ERRORS as e:
log_embedding_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, embedding_message(e)
)
except Exception as e:
log_unexpected_error(ctx, e)
await rollback_and_persist_failure(
self.session, document, safe_exception_message(e)
)
with contextlib.suppress(Exception):
await self.session.refresh(document)
return document

View file

@ -0,0 +1,118 @@
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class PipelineLogContext:
connector_id: int | None
search_space_id: int
unique_id: str # always available from ConnectorDocument
doc_id: int | None = None # set once the DB row exists (index phase only)
class LogMessages:
# prepare_for_indexing
DOCUMENT_QUEUED = "New document queued for indexing."
DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
# index
INDEX_STARTED = "Document indexing started."
INDEX_SUCCESS = "Document indexed successfully."
LLM_RETRYABLE = "Retryable LLM error — document marked failed, will retry on next sync."
LLM_PERMANENT = "Permanent LLM error — document marked failed."
EMBEDDING_FAILED = "Embedding error — document marked failed."
CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
UNEXPECTED = "Unexpected error — document marked failed."
def _format_context(ctx: PipelineLogContext) -> str:
parts = [
f"connector_id={ctx.connector_id}",
f"search_space_id={ctx.search_space_id}",
f"unique_id={ctx.unique_id}",
]
if ctx.doc_id is not None:
parts.append(f"doc_id={ctx.doc_id}")
return " ".join(parts)
def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
try:
parts = [msg, _format_context(ctx)]
for key, val in extra.items():
parts.append(f"{key}={val}")
return " ".join(parts)
except Exception:
return msg
def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra) -> None:
# Logging must never raise — a broken log call inside an except block would
# chain with the original exception and mask it entirely.
try:
message = _build_message(msg, ctx, **extra)
level_fn(message, exc_info=exc_info)
except Exception:
pass
# ── prepare_for_indexing ──────────────────────────────────────────────────────
def log_document_queued(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
def log_document_updated(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)
def log_document_requeued(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)
def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc)
def log_race_condition(ctx: PipelineLogContext) -> None:
_safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)
def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)
# ── index ─────────────────────────────────────────────────────────────────────
def log_index_started(ctx: PipelineLogContext) -> None:
_safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)
def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
_safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)
def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)
def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)
def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)
def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)
def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
_safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)

View file

@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType, Log, Notification
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
from app.services.llm_service import get_user_long_context_llm
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
@ -33,7 +34,6 @@ from .base import (
check_document_by_unique_identifier,
check_duplicate_document,
get_current_timestamp,
safe_set_chunks,
)
from .markdown_processor import add_received_markdown_file_document
@ -1863,7 +1863,7 @@ async def process_file_in_background_with_document(
)
return None
# ===== STEP 3: Generate embeddings and chunks =====
# ===== STEP 3+4: Index via pipeline =====
if notification:
await NotificationService.document_processing.notify_processing_progress(
session, notification, stage="chunking"
@ -1871,57 +1871,22 @@ async def process_file_in_background_with_document(
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if user_llm:
document_metadata = {
"file_name": filename,
"etl_service": etl_service,
"document_type": "File Document",
}
summary_content, summary_embedding = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
# Fallback: use truncated content as summary
summary_content = markdown_content[:4000]
from app.config import config
summary_embedding = config.embedding_model_instance.embed(summary_content)
chunks = await create_document_chunks(markdown_content)
# ===== STEP 4: Update document to READY =====
from sqlalchemy.orm.attributes import flag_modified
document.title = filename
document.content = summary_content
document.content_hash = content_hash
document.embedding = summary_embedding
document.document_metadata = {
"FILE_NAME": filename,
"ETL_SERVICE": etl_service or "UNKNOWN",
**(document.document_metadata or {}),
}
flag_modified(document, "document_metadata")
# Use safe_set_chunks to avoid async issues
safe_set_chunks(document, chunks)
document.source_markdown = markdown_content
document.content_needs_reindexing = False
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready() # Shows checkmark in UI
await session.commit()
await session.refresh(document)
await index_uploaded_file(
markdown_content=markdown_content,
filename=filename,
etl_service=etl_service,
search_space_id=search_space_id,
user_id=user_id,
session=session,
llm=user_llm,
)
await task_logger.log_task_success(
log_entry,
f"Successfully processed file: {filename}",
{
"document_id": document.id,
"content_hash": content_hash,
"file_type": etl_service,
"chunks_count": len(chunks),
},
)

View file

@ -72,6 +72,22 @@ dependencies = [
[dependency-groups]
dev = [
"ruff>=0.12.5",
"pytest>=8.0",
"pytest-asyncio>=0.25",
"pytest-mock>=3.14",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
asyncio_default_test_loop_scope = "session"
testpaths = ["tests"]
markers = [
"unit: pure logic tests, no DB or external services",
"integration: tests that require a real PostgreSQL database",
]
filterwarnings = [
"ignore::UserWarning:chonkie",
]
[tool.ruff]

View file

View file

@ -0,0 +1,36 @@
import pytest
from app.db import DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
@pytest.fixture
def sample_user_id() -> str:
return "00000000-0000-0000-0000-000000000001"
@pytest.fixture
def sample_search_space_id() -> int:
return 1
@pytest.fixture
def sample_connector_id() -> int:
return 42
@pytest.fixture
def make_connector_document():
def _make(**overrides):
defaults = {
"title": "Test Document",
"source_markdown": "## Heading\n\nSome content.",
"unique_id": "test-id-001",
"document_type": DocumentType.CLICKUP_CONNECTOR,
"search_space_id": 1,
"connector_id": 1,
"created_by_id": "00000000-0000-0000-0000-000000000001",
}
defaults.update(overrides)
return ConnectorDocument(**defaults)
return _make

View file

@ -0,0 +1,164 @@
import os
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool
from app.db import Base, SearchSpace, SearchSourceConnector, SearchSourceConnectorType
from app.db import User
from app.db import DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
_EMBEDDING_DIM = 1024 # must match the Vector() dimension used in DB column creation
_DEFAULT_TEST_DB = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
@pytest_asyncio.fixture(scope="session")
async def async_engine():
engine = create_async_engine(
TEST_DATABASE_URL,
poolclass=NullPool,
echo=False,
# Required for asyncpg + savepoints: disables prepared statement cache
# to prevent "another operation is in progress" errors during savepoint rollbacks.
connect_args={"prepared_statement_cache_size": 0},
)
async with engine.begin() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
await conn.run_sync(Base.metadata.create_all)
yield engine
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
# DROP SCHEMA CASCADE handles this without needing topological sort.
async with engine.begin() as conn:
await conn.execute(text("DROP SCHEMA public CASCADE"))
await conn.execute(text("CREATE SCHEMA public"))
await engine.dispose()
@pytest_asyncio.fixture
async def db_session(async_engine) -> AsyncSession:
# Bind the session to a connection that holds an outer transaction.
# join_transaction_mode="create_savepoint" makes session.commit() release
# a SAVEPOINT instead of committing the outer transaction, so the final
# transaction.rollback() undoes everything — including commits made by the
# service under test — leaving the DB clean for the next test.
async with async_engine.connect() as conn:
transaction = await conn.begin()
async with AsyncSession(
bind=conn,
expire_on_commit=False,
join_transaction_mode="create_savepoint",
) as session:
yield session
await transaction.rollback()
@pytest_asyncio.fixture
async def db_user(db_session: AsyncSession) -> User:
user = User(
id=uuid.uuid4(),
email="test@surfsense.net",
hashed_password="hashed",
is_active=True,
is_superuser=False,
is_verified=True,
)
db_session.add(user)
await db_session.flush()
return user
@pytest_asyncio.fixture
async def db_connector(db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace") -> SearchSourceConnector:
connector = SearchSourceConnector(
name="Test Connector",
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
config={},
search_space_id=db_search_space.id,
user_id=db_user.id,
)
db_session.add(connector)
await db_session.flush()
return connector
@pytest_asyncio.fixture
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
space = SearchSpace(
name="Test Space",
user_id=db_user.id,
)
db_session.add(space)
await db_session.flush()
return space
@pytest.fixture
def patched_summarize(monkeypatch) -> AsyncMock:
mock = AsyncMock(return_value="Mocked summary.")
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
mock,
)
return mock
@pytest.fixture
def patched_summarize_raises(monkeypatch) -> AsyncMock:
mock = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
mock,
)
return mock
@pytest.fixture
def patched_embed_text(monkeypatch) -> MagicMock:
mock = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_text",
mock,
)
return mock
@pytest.fixture
def patched_chunk_text(monkeypatch) -> MagicMock:
mock = MagicMock(return_value=["Test chunk content."])
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock,
)
return mock
@pytest.fixture
def make_connector_document(db_connector, db_user):
"""Integration-scoped override: uses real DB connector and user IDs."""
def _make(**overrides):
defaults = {
"title": "Test Document",
"source_markdown": "## Heading\n\nSome content.",
"unique_id": "test-id-001",
"document_type": DocumentType.CLICKUP_CONNECTOR,
"search_space_id": db_connector.search_space_id,
"connector_id": db_connector.id,
"created_by_id": str(db_user.id),
}
defaults.update(overrides)
return ConnectorDocument(**defaults)
return _make

View file

@ -0,0 +1,91 @@
import pytest
from sqlalchemy import select
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
"""Document status is READY after successful indexing."""
await index_uploaded_file(
markdown_content="## Hello\n\nSome content.",
filename="test.pdf",
etl_service="UNSTRUCTURED",
search_space_id=db_search_space.id,
user_id=str(db_user.id),
session=db_session,
llm=mocker.Mock(),
)
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
document = result.scalars().first()
assert DocumentStatus.is_state(document.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
"""Document content is set to the LLM-generated summary."""
await index_uploaded_file(
markdown_content="## Hello\n\nSome content.",
filename="test.pdf",
etl_service="UNSTRUCTURED",
search_space_id=db_search_space.id,
user_id=str(db_user.id),
session=db_session,
llm=mocker.Mock(),
)
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
document = result.scalars().first()
assert document.content == "Mocked summary."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker):
"""Chunks derived from the source markdown are persisted in the DB."""
await index_uploaded_file(
markdown_content="## Hello\n\nSome content.",
filename="test.pdf",
etl_service="UNSTRUCTURED",
search_space_id=db_search_space.id,
user_id=str(db_user.id),
session=db_session,
llm=mocker.Mock(),
)
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
document = result.scalars().first()
chunks_result = await db_session.execute(
select(Chunk).filter(Chunk.document_id == document.id)
)
chunks = chunks_result.scalars().all()
assert len(chunks) == 1
assert chunks[0].content == "Test chunk content."
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker):
"""RuntimeError is raised when the indexing step fails so the caller can fire a failure notification."""
with pytest.raises(RuntimeError):
await index_uploaded_file(
markdown_content="## Hello\n\nSome content.",
filename="test.pdf",
etl_service="UNSTRUCTURED",
search_space_id=db_search_space.id,
user_id=str(db_user.id),
session=db_session,
llm=mocker.Mock(),
)

View file

@ -0,0 +1,266 @@
import pytest
from sqlalchemy import select
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_sets_status_ready(
db_session, db_search_space, make_connector_document, mocker,
):
"""Document status is READY after successful indexing."""
connector_doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_content_is_summary_when_should_summarize_true(
db_session, db_search_space, make_connector_document, mocker,
):
"""Document content is set to the LLM-generated summary when should_summarize=True."""
connector_doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert reloaded.content == "Mocked summary."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_content_is_source_markdown_when_should_summarize_false(
db_session, db_search_space, make_connector_document,
):
"""Document content is set to source_markdown verbatim when should_summarize=False."""
connector_doc = make_connector_document(
search_space_id=db_search_space.id,
should_summarize=False,
source_markdown="## Raw content",
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=None)
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert reloaded.content == "## Raw content"
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_chunks_written_to_db(
db_session, db_search_space, make_connector_document, mocker,
):
"""Chunks derived from source_markdown are persisted in the DB."""
connector_doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(
select(Chunk).filter(Chunk.document_id == document_id)
)
chunks = result.scalars().all()
assert len(chunks) == 1
assert chunks[0].content == "Test chunk content."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_embedding_written_to_db(
db_session, db_search_space, make_connector_document, mocker,
):
"""Document embedding vector is persisted in the DB after indexing."""
connector_doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert reloaded.embedding is not None
assert len(reloaded.embedding) == 1024
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_updated_at_advances_after_indexing(
db_session, db_search_space, make_connector_document, mocker,
):
"""updated_at timestamp is later after indexing than it was at prepare time."""
connector_doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
result = await db_session.execute(select(Document).filter(Document.id == document_id))
updated_at_pending = result.scalars().first().updated_at
await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id))
updated_at_ready = result.scalars().first().updated_at
assert updated_at_ready > updated_at_pending
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_no_llm_falls_back_to_source_markdown(
db_session, db_search_space, make_connector_document,
):
"""When llm=None and no fallback_summary, content falls back to source_markdown."""
connector_doc = make_connector_document(
search_space_id=db_search_space.id,
should_summarize=True,
source_markdown="## Fallback content",
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=None)
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
assert reloaded.content == "## Fallback content"
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_fallback_summary_used_when_llm_unavailable(
db_session, db_search_space, make_connector_document,
):
"""fallback_summary is used as content when llm=None and should_summarize=True."""
connector_doc = make_connector_document(
search_space_id=db_search_space.id,
should_summarize=True,
source_markdown="## Full raw content",
fallback_summary="Short pre-built summary.",
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document_id = prepared[0].id
await service.index(prepared[0], connector_doc, llm=None)
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
assert reloaded.content == "Short pre-built summary."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_reindex_replaces_old_chunks(
db_session, db_search_space, make_connector_document, mocker,
):
"""Re-indexing a document replaces its old chunks rather than appending."""
connector_doc = make_connector_document(
search_space_id=db_search_space.id,
source_markdown="## v1",
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=mocker.Mock())
updated_doc = make_connector_document(
search_space_id=db_search_space.id,
source_markdown="## v2",
)
re_prepared = await service.prepare_for_indexing([updated_doc])
await service.index(re_prepared[0], updated_doc, llm=mocker.Mock())
result = await db_session.execute(
select(Chunk).filter(Chunk.document_id == document_id)
)
chunks = result.scalars().all()
assert len(chunks) == 1
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_llm_error_sets_status_failed(
db_session, db_search_space, make_connector_document, mocker,
):
"""Document status is FAILED when the LLM raises during indexing."""
connector_doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED)
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_llm_error_leaves_no_partial_data(
db_session, db_search_space, make_connector_document, mocker,
):
"""A failed indexing attempt leaves no partial embedding or chunks in the DB."""
connector_doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([connector_doc])
document = prepared[0]
document_id = document.id
await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert reloaded.embedding is None
assert reloaded.content == "Pending..."
chunks_result = await db_session.execute(
select(Chunk).filter(Chunk.document_id == document_id)
)
assert chunks_result.scalars().all() == []

View file

@ -0,0 +1,377 @@
import pytest
from sqlalchemy import select
from app.db import Document, DocumentStatus
from app.indexing_pipeline.document_hashing import compute_content_hash as real_compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.integration
async def test_new_document_is_persisted_with_pending_status(
db_session, db_search_space, make_connector_document
):
"""A new document is created in the DB with PENDING status and correct markdown."""
doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
results = await service.prepare_for_indexing([doc])
assert len(results) == 1
document_id = results[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert reloaded is not None
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
assert reloaded.source_markdown == doc.source_markdown
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_unchanged_ready_document_is_skipped(
db_session, db_search_space, make_connector_document, mocker,
):
"""A READY document with unchanged content is not returned for re-indexing."""
doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
# Index fully so the document reaches ready state
prepared = await service.prepare_for_indexing([doc])
await service.index(prepared[0], doc, llm=mocker.Mock())
# Same content on the next run — a ready document must be skipped
results = await service.prepare_for_indexing([doc])
assert results == []
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_title_only_change_updates_title_in_db(
db_session, db_search_space, make_connector_document, mocker,
):
"""A title-only change updates the DB title without re-queuing the document."""
original = make_connector_document(search_space_id=db_search_space.id, title="Original Title")
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([original])
document_id = prepared[0].id
await service.index(prepared[0], original, llm=mocker.Mock())
renamed = make_connector_document(search_space_id=db_search_space.id, title="Updated Title")
results = await service.prepare_for_indexing([renamed])
assert results == []
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert reloaded.title == "Updated Title"
async def test_changed_content_is_returned_for_reprocessing(
db_session, db_search_space, make_connector_document
):
"""A document with changed content is returned for re-indexing with updated markdown."""
original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original])
original_id = first[0].id
updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
results = await service.prepare_for_indexing([updated])
assert len(results) == 1
assert results[0].id == original_id
result = await db_session.execute(select(Document).filter(Document.id == original_id))
reloaded = result.scalars().first()
assert reloaded.source_markdown == "## v2"
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
async def test_all_documents_in_batch_are_persisted(
db_session, db_search_space, make_connector_document
):
"""All documents in a batch are persisted and returned."""
docs = [
make_connector_document(search_space_id=db_search_space.id, unique_id="id-1", title="Doc 1", source_markdown="## Content 1"),
make_connector_document(search_space_id=db_search_space.id, unique_id="id-2", title="Doc 2", source_markdown="## Content 2"),
make_connector_document(search_space_id=db_search_space.id, unique_id="id-3", title="Doc 3", source_markdown="## Content 3"),
]
service = IndexingPipelineService(session=db_session)
results = await service.prepare_for_indexing(docs)
assert len(results) == 3
result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
rows = result.scalars().all()
assert len(rows) == 3
async def test_duplicate_in_batch_is_persisted_once(
db_session, db_search_space, make_connector_document
):
"""The same document passed twice in a batch is only persisted once."""
doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
results = await service.prepare_for_indexing([doc, doc])
assert len(results) == 1
result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
rows = result.scalars().all()
assert len(rows) == 1
async def test_created_by_id_is_persisted(
db_session, db_user, db_search_space, make_connector_document
):
"""created_by_id from the connector document is persisted on the DB row."""
doc = make_connector_document(
search_space_id=db_search_space.id,
created_by_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
results = await service.prepare_for_indexing([doc])
document_id = results[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert str(reloaded.created_by_id) == str(db_user.id)
async def test_metadata_is_updated_when_content_changes(
db_session, db_search_space, make_connector_document
):
"""document_metadata is overwritten with the latest metadata when content changes."""
original = make_connector_document(
search_space_id=db_search_space.id,
source_markdown="## v1",
metadata={"status": "in_progress"},
)
service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original])
document_id = first[0].id
updated = make_connector_document(
search_space_id=db_search_space.id,
source_markdown="## v2",
metadata={"status": "done"},
)
await service.prepare_for_indexing([updated])
result = await db_session.execute(select(Document).filter(Document.id == document_id))
reloaded = result.scalars().first()
assert reloaded.document_metadata == {"status": "done"}
async def test_updated_at_advances_when_title_only_changes(
db_session, db_search_space, make_connector_document
):
"""updated_at advances even when only the title changes."""
original = make_connector_document(search_space_id=db_search_space.id, title="Old Title")
service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original])
document_id = first[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id))
updated_at_v1 = result.scalars().first().updated_at
renamed = make_connector_document(search_space_id=db_search_space.id, title="New Title")
await service.prepare_for_indexing([renamed])
result = await db_session.execute(select(Document).filter(Document.id == document_id))
updated_at_v2 = result.scalars().first().updated_at
assert updated_at_v2 > updated_at_v1
async def test_updated_at_advances_when_content_changes(
db_session, db_search_space, make_connector_document
):
"""updated_at advances when document content changes."""
original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original])
document_id = first[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id))
updated_at_v1 = result.scalars().first().updated_at
updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
await service.prepare_for_indexing([updated])
result = await db_session.execute(select(Document).filter(Document.id == document_id))
updated_at_v2 = result.scalars().first().updated_at
assert updated_at_v2 > updated_at_v1
async def test_same_content_from_different_source_skipped_in_single_batch(
db_session, db_search_space, make_connector_document
):
"""Two documents with identical content in the same batch result in only one being persisted."""
first = make_connector_document(
search_space_id=db_search_space.id,
unique_id="source-a",
source_markdown="## Shared content",
)
second = make_connector_document(
search_space_id=db_search_space.id,
unique_id="source-b",
source_markdown="## Shared content",
)
service = IndexingPipelineService(session=db_session)
results = await service.prepare_for_indexing([first, second])
assert len(results) == 1
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
assert len(result.scalars().all()) == 1
async def test_same_content_from_different_source_is_skipped(
db_session, db_search_space, make_connector_document
):
"""A document with content identical to an already-indexed document is skipped."""
first = make_connector_document(
search_space_id=db_search_space.id,
unique_id="source-a",
source_markdown="## Shared content",
)
second = make_connector_document(
search_space_id=db_search_space.id,
unique_id="source-b",
source_markdown="## Shared content",
)
service = IndexingPipelineService(session=db_session)
await service.prepare_for_indexing([first])
results = await service.prepare_for_indexing([second])
assert results == []
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
assert len(result.scalars().all()) == 1
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
async def test_failed_document_with_unchanged_content_is_requeued(
db_session, db_search_space, make_connector_document, mocker,
):
"""A FAILED document with unchanged content is re-queued as PENDING on the next run."""
doc = make_connector_document(search_space_id=db_search_space.id)
service = IndexingPipelineService(session=db_session)
# First run: document is created and indexing crashes → status = failed
prepared = await service.prepare_for_indexing([doc])
document_id = prepared[0].id
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id))
assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.FAILED)
# Next run: same content, pipeline must re-queue the failed document
results = await service.prepare_for_indexing([doc])
assert len(results) == 1
assert results[0].id == document_id
result = await db_session.execute(select(Document).filter(Document.id == document_id))
assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.PENDING)
async def test_title_and_content_change_updates_both_and_returns_document(
db_session, db_search_space, make_connector_document
):
"""When both title and content change, both are updated and the document is returned for re-indexing."""
original = make_connector_document(
search_space_id=db_search_space.id,
title="Original Title",
source_markdown="## v1",
)
service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original])
original_id = first[0].id
updated = make_connector_document(
search_space_id=db_search_space.id,
title="Updated Title",
source_markdown="## v2",
)
results = await service.prepare_for_indexing([updated])
assert len(results) == 1
assert results[0].id == original_id
result = await db_session.execute(select(Document).filter(Document.id == original_id))
reloaded = result.scalars().first()
assert reloaded.title == "Updated Title"
assert reloaded.source_markdown == "## v2"
async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_persisted(
db_session, db_search_space, make_connector_document, monkeypatch,
):
"""
A per-document error during prepare_for_indexing must be isolated.
The two valid documents around the failing one must still be persisted.
"""
docs = [
make_connector_document(
search_space_id=db_search_space.id,
unique_id="good-1",
source_markdown="## Good doc 1",
),
make_connector_document(
search_space_id=db_search_space.id,
unique_id="will-fail",
source_markdown="## Bad doc",
),
make_connector_document(
search_space_id=db_search_space.id,
unique_id="good-2",
source_markdown="## Good doc 2",
),
]
def compute_content_hash_with_error(doc):
if doc.unique_id == "will-fail":
raise RuntimeError("Simulated per-document failure")
return real_compute_content_hash(doc)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.compute_content_hash",
compute_content_hash_with_error,
)
service = IndexingPipelineService(session=db_session)
results = await service.prepare_for_indexing(docs)
assert len(results) == 2
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
assert len(result.scalars().all()) == 2

View file

View file

@ -0,0 +1,33 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
@pytest.fixture
def patched_summarizer_chain(monkeypatch):
chain = MagicMock()
chain.ainvoke = AsyncMock(return_value=MagicMock(content="The summary."))
template = MagicMock()
template.__or__ = MagicMock(return_value=chain)
monkeypatch.setattr(
"app.indexing_pipeline.document_summarizer.SUMMARY_PROMPT_TEMPLATE",
template,
)
return chain
@pytest.fixture
def patched_chunker_instance(monkeypatch):
mock = MagicMock()
mock.chunk.return_value = [MagicMock(text="prose chunk")]
monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.chunker_instance", mock)
return mock
@pytest.fixture
def patched_code_chunker_instance(monkeypatch):
mock = MagicMock()
mock.chunk.return_value = [MagicMock(text="code chunk")]
monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock)
return mock

View file

@ -0,0 +1,112 @@
import pytest
from pydantic import ValidationError
from app.db import DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
def test_valid_document_created_with_required_fields():
"""All optional fields default correctly when only required fields are supplied."""
doc = ConnectorDocument(
title="Task",
source_markdown="## Task\n\nSome content.",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
connector_id=42,
created_by_id="00000000-0000-0000-0000-000000000001",
)
assert doc.should_summarize is True
assert doc.should_use_code_chunker is False
assert doc.metadata == {}
assert doc.connector_id == 42
assert doc.created_by_id == "00000000-0000-0000-0000-000000000001"
def test_omitting_created_by_id_raises():
"""Omitting created_by_id raises a validation error."""
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown="## Content",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
connector_id=42,
)
def test_empty_source_markdown_raises():
"""Empty source_markdown raises a validation error."""
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown="",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
)
def test_whitespace_only_source_markdown_raises():
"""Whitespace-only source_markdown raises a validation error."""
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown=" \n\t ",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
)
def test_empty_title_raises():
"""Empty title raises a validation error."""
with pytest.raises(ValidationError):
ConnectorDocument(
title="",
source_markdown="## Content",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
)
def test_empty_created_by_id_raises():
"""Empty created_by_id raises a validation error."""
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown="## Content",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
connector_id=42,
created_by_id="",
)
def test_zero_search_space_id_raises():
"""search_space_id of zero raises a validation error."""
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown="## Content",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=0,
connector_id=42,
created_by_id="00000000-0000-0000-0000-000000000001",
)
def test_empty_unique_id_raises():
"""Empty unique_id raises a validation error."""
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown="## Content",
unique_id="",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
)

View file

@ -0,0 +1,21 @@
import pytest
from app.indexing_pipeline.document_chunker import chunk_text
pytestmark = pytest.mark.unit
@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance")
def test_uses_code_chunker_when_flag_is_true():
"""Code chunker is selected when use_code_chunker=True."""
result = chunk_text("def foo(): pass", use_code_chunker=True)
assert result == ["code chunk"]
@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance")
def test_uses_default_chunker_when_flag_is_false():
"""Default prose chunker is selected when use_code_chunker=False."""
result = chunk_text("Some prose text.", use_code_chunker=False)
assert result == ["prose chunk"]

View file

@ -0,0 +1,48 @@
import pytest
from app.db import DocumentType
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
pytestmark = pytest.mark.unit
def test_different_unique_id_produces_different_hash(make_connector_document):
"""Two documents with different unique_ids produce different identifier hashes."""
doc_a = make_connector_document(unique_id="id-001")
doc_b = make_connector_document(unique_id="id-002")
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)
def test_different_search_space_produces_different_identifier_hash(make_connector_document):
"""Same document in different search spaces produces different identifier hashes."""
doc_a = make_connector_document(search_space_id=1)
doc_b = make_connector_document(search_space_id=2)
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)
def test_different_document_type_produces_different_identifier_hash(make_connector_document):
"""Same unique_id with different document types produces different identifier hashes."""
doc_a = make_connector_document(document_type=DocumentType.CLICKUP_CONNECTOR)
doc_b = make_connector_document(document_type=DocumentType.NOTION_CONNECTOR)
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)
def test_same_content_same_space_produces_same_content_hash(make_connector_document):
"""Identical content in the same search space always produces the same content hash."""
doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
doc_b = make_connector_document(source_markdown="Hello world", search_space_id=1)
assert compute_content_hash(doc_a) == compute_content_hash(doc_b)
def test_same_content_different_space_produces_different_content_hash(make_connector_document):
"""Identical content in different search spaces produces different content hashes."""
doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
doc_b = make_connector_document(source_markdown="Hello world", search_space_id=2)
assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
def test_different_content_produces_different_content_hash(make_connector_document):
"""Different source markdown produces different content hashes."""
doc_a = make_connector_document(source_markdown="Original content")
doc_b = make_connector_document(source_markdown="Updated content")
assert compute_content_hash(doc_a) != compute_content_hash(doc_b)

View file

@ -0,0 +1,42 @@
import pytest
from unittest.mock import MagicMock
from app.indexing_pipeline.document_summarizer import summarize_document
pytestmark = pytest.mark.unit
@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_without_metadata_returns_raw_summary():
"""Summarizer returns the LLM output directly when no metadata is provided."""
result = await summarize_document("# Content", llm=MagicMock(model="gpt-4"))
assert result == "The summary."
@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_with_metadata_includes_metadata_values_in_output():
"""Non-empty metadata values are prepended to the summary output."""
result = await summarize_document(
"# Content",
llm=MagicMock(model="gpt-4"),
metadata={"author": "Alice", "source": "Notion"},
)
assert "Alice" in result
assert "Notion" in result
@pytest.mark.usefixtures("patched_summarizer_chain")
async def test_with_metadata_omits_empty_fields_from_output():
"""Empty metadata fields are omitted from the summary output."""
result = await summarize_document(
"# Content",
llm=MagicMock(model="gpt-4"),
metadata={"author": "Alice", "description": ""},
)
assert "Alice" in result
assert "description" not in result.lower()

6655
surfsense_backend/uv.lock generated

File diff suppressed because it is too large Load diff