add file upload adapter and make index() return refreshed document

This commit is contained in:
CREDO23 2026-02-25 19:56:59 +02:00
parent 86ecb82c6e
commit cad400be1b
6 changed files with 82 additions and 37 deletions

View file

@ -0,0 +1,44 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
async def index_uploaded_file(
    markdown_content: str,
    filename: str,
    etl_service: str,
    search_space_id: int,
    user_id: str,
    session: AsyncSession,
    llm,
) -> None:
    """Index an uploaded file's markdown through the indexing pipeline.

    The markdown is wrapped in a ``ConnectorDocument`` of type
    ``DocumentType.FILE`` with no connector attached, summarization enabled
    and the code chunker disabled. The document is then passed through
    ``IndexingPipelineService.prepare_for_indexing`` and
    ``IndexingPipelineService.index``.

    Args:
        markdown_content: Markdown text of the uploaded file.
        filename: Name of the uploaded file; used as the document title and
            as its ``unique_id``.
        etl_service: Value recorded under ``"etl_service"`` in the document
            metadata.
        search_space_id: Search space the document is created in.
        user_id: Value stored as the document's ``created_by_id``.
        session: Async SQLAlchemy session handed to the pipeline service.
        llm: Forwarded unchanged to ``IndexingPipelineService.index``.

    Raises:
        RuntimeError: If ``prepare_for_indexing`` returns no documents, or if
            the document returned by ``index`` is not in the ``READY`` state.
            In the latter case the error text is the status ``"message"``
            entry, falling back to ``"Indexing failed"``.
    """
    file_metadata = {
        "file_name": filename,
        "etl_service": etl_service,
        "document_type": "File Document",
    }
    upload = ConnectorDocument(
        title=filename,
        source_markdown=markdown_content,
        # The filename doubles as the identifier handed to the pipeline.
        unique_id=filename,
        document_type=DocumentType.FILE,
        search_space_id=search_space_id,
        created_by_id=user_id,
        # Uploads do not originate from a connector.
        connector_id=None,
        should_summarize=True,
        should_use_code_chunker=False,
        # Leading slice of the markdown, supplied as the fallback summary.
        fallback_summary=markdown_content[:4000],
        metadata=file_metadata,
    )

    pipeline = IndexingPipelineService(session)
    prepared = await pipeline.prepare_for_indexing([upload])
    if not prepared:
        raise RuntimeError("prepare_for_indexing returned no documents")

    result = await pipeline.index(prepared[0], upload, llm)
    if DocumentStatus.is_state(result.status, DocumentStatus.READY):
        return
    # Surface the failure recorded on the document status to the caller.
    raise RuntimeError(result.status.get("message", "Indexing failed"))

View file

@ -14,7 +14,7 @@ class ConnectorDocument(BaseModel):
should_use_code_chunker: bool = False
fallback_summary: str | None = None
metadata: dict = {}
connector_id: int = Field(gt=0)
connector_id: int | None = None
created_by_id: str
@field_validator("title", "source_markdown", "unique_id", "created_by_id")

View file

@ -1,3 +1,4 @@
import contextlib
from datetime import UTC, datetime
from sqlalchemy import delete, select
@ -7,8 +8,14 @@ from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_text
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
from app.indexing_pipeline.document_persistence import attach_chunks_to_document, rollback_and_persist_failure
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
attach_chunks_to_document,
rollback_and_persist_failure,
)
from app.indexing_pipeline.document_summarizer import summarize_document
from app.indexing_pipeline.exceptions import (
EMBEDDING_ERRORS,
@ -74,7 +81,9 @@ class IndexingPipelineService:
seen_hashes.add(unique_identifier_hash)
result = await self.session.execute(
select(Document).filter(Document.unique_identifier_hash == unique_identifier_hash)
select(Document).filter(
Document.unique_identifier_hash == unique_identifier_hash
)
)
existing = result.scalars().first()
@ -83,7 +92,9 @@ class IndexingPipelineService:
if existing.title != connector_doc.title:
existing.title = connector_doc.title
existing.updated_at = datetime.now(UTC)
if not DocumentStatus.is_state(existing.status, DocumentStatus.READY):
if not DocumentStatus.is_state(
existing.status, DocumentStatus.READY
):
existing.status = DocumentStatus.pending()
existing.updated_at = datetime.now(UTC)
documents.append(existing)
@ -144,7 +155,7 @@ class IndexingPipelineService:
async def index(
self, document: Document, connector_doc: ConnectorDocument, llm
) -> None:
) -> Document:
"""
Run summarization, embedding, and chunking for a document and persist the results.
"""
@ -192,20 +203,35 @@ class IndexingPipelineService:
except RETRYABLE_LLM_ERRORS as e:
log_retryable_llm_error(ctx, e)
await rollback_and_persist_failure(self.session, document, llm_retryable_message(e))
await rollback_and_persist_failure(
self.session, document, llm_retryable_message(e)
)
except PERMANENT_LLM_ERRORS as e:
log_permanent_llm_error(ctx, e)
await rollback_and_persist_failure(self.session, document, llm_permanent_message(e))
await rollback_and_persist_failure(
self.session, document, llm_permanent_message(e)
)
except RecursionError as e:
log_chunking_overflow(ctx, e)
await rollback_and_persist_failure(self.session, document, PipelineMessages.CHUNKING_OVERFLOW)
await rollback_and_persist_failure(
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
)
except EMBEDDING_ERRORS as e:
log_embedding_error(ctx, e)
await rollback_and_persist_failure(self.session, document, embedding_message(e))
await rollback_and_persist_failure(
self.session, document, embedding_message(e)
)
except Exception as e:
log_unexpected_error(ctx, e)
await rollback_and_persist_failure(self.session, document, safe_exception_message(e))
await rollback_and_persist_failure(
self.session, document, safe_exception_message(e)
)
with contextlib.suppress(Exception):
await self.session.refresh(document)
return document

View file

@ -6,7 +6,7 @@ logger = logging.getLogger(__name__)
@dataclass
class PipelineLogContext:
connector_id: int
connector_id: int | None
search_space_id: int
unique_id: str # always available from ConnectorDocument
doc_id: int | None = None # set once the DB row exists (index phase only)

View file

@ -22,18 +22,6 @@ def test_valid_document_created_with_required_fields():
assert doc.created_by_id == "00000000-0000-0000-0000-000000000001"
def test_omitting_connector_id_raises():
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown="## Content",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
created_by_id="00000000-0000-0000-0000-000000000001",
)
def test_omitting_created_by_id_raises():
with pytest.raises(ValidationError):
ConnectorDocument(
@ -92,19 +80,6 @@ def test_empty_created_by_id_raises():
)
def test_zero_connector_id_raises():
with pytest.raises(ValidationError):
ConnectorDocument(
title="Task",
source_markdown="## Content",
unique_id="task-1",
document_type=DocumentType.CLICKUP_CONNECTOR,
search_space_id=1,
connector_id=0,
created_by_id="00000000-0000-0000-0000-000000000001",
)
def test_zero_search_space_id_raises():
with pytest.raises(ValidationError):
ConnectorDocument(