mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
add file upload adapter and make index() return refreshed document
This commit is contained in:
parent
86ecb82c6e
commit
cad400be1b
6 changed files with 82 additions and 37 deletions
|
|
@ -0,0 +1,44 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
|
||||
async def index_uploaded_file(
|
||||
markdown_content: str,
|
||||
filename: str,
|
||||
etl_service: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
session: AsyncSession,
|
||||
llm,
|
||||
) -> None:
|
||||
connector_doc = ConnectorDocument(
|
||||
title=filename,
|
||||
source_markdown=markdown_content,
|
||||
unique_id=filename,
|
||||
document_type=DocumentType.FILE,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
connector_id=None,
|
||||
should_summarize=True,
|
||||
should_use_code_chunker=False,
|
||||
fallback_summary=markdown_content[:4000],
|
||||
metadata={
|
||||
"file_name": filename,
|
||||
"etl_service": etl_service,
|
||||
"document_type": "File Document",
|
||||
},
|
||||
)
|
||||
|
||||
service = IndexingPipelineService(session)
|
||||
documents = await service.prepare_for_indexing([connector_doc])
|
||||
|
||||
if not documents:
|
||||
raise RuntimeError("prepare_for_indexing returned no documents")
|
||||
|
||||
indexed = await service.index(documents[0], connector_doc, llm)
|
||||
|
||||
if not DocumentStatus.is_state(indexed.status, DocumentStatus.READY):
|
||||
raise RuntimeError(indexed.status.get("message", "Indexing failed"))
|
||||
|
|
@ -14,7 +14,7 @@ class ConnectorDocument(BaseModel):
|
|||
should_use_code_chunker: bool = False
|
||||
fallback_summary: str | None = None
|
||||
metadata: dict = {}
|
||||
connector_id: int = Field(gt=0)
|
||||
connector_id: int | None = None
|
||||
created_by_id: str
|
||||
|
||||
@field_validator("title", "source_markdown", "unique_id", "created_by_id")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
|
@ -7,8 +8,14 @@ from app.db import Chunk, Document, DocumentStatus
|
|||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_embedder import embed_text
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
|
||||
from app.indexing_pipeline.document_persistence import attach_chunks_to_document, rollback_and_persist_failure
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
attach_chunks_to_document,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.document_summarizer import summarize_document
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
EMBEDDING_ERRORS,
|
||||
|
|
@ -74,7 +81,9 @@ class IndexingPipelineService:
|
|||
seen_hashes.add(unique_identifier_hash)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(Document).filter(Document.unique_identifier_hash == unique_identifier_hash)
|
||||
select(Document).filter(
|
||||
Document.unique_identifier_hash == unique_identifier_hash
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
|
|
@ -83,7 +92,9 @@ class IndexingPipelineService:
|
|||
if existing.title != connector_doc.title:
|
||||
existing.title = connector_doc.title
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
if not DocumentStatus.is_state(existing.status, DocumentStatus.READY):
|
||||
if not DocumentStatus.is_state(
|
||||
existing.status, DocumentStatus.READY
|
||||
):
|
||||
existing.status = DocumentStatus.pending()
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
documents.append(existing)
|
||||
|
|
@ -144,7 +155,7 @@ class IndexingPipelineService:
|
|||
|
||||
async def index(
|
||||
self, document: Document, connector_doc: ConnectorDocument, llm
|
||||
) -> None:
|
||||
) -> Document:
|
||||
"""
|
||||
Run summarization, embedding, and chunking for a document and persist the results.
|
||||
"""
|
||||
|
|
@ -192,20 +203,35 @@ class IndexingPipelineService:
|
|||
|
||||
except RETRYABLE_LLM_ERRORS as e:
|
||||
log_retryable_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(self.session, document, llm_retryable_message(e))
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_retryable_message(e)
|
||||
)
|
||||
|
||||
except PERMANENT_LLM_ERRORS as e:
|
||||
log_permanent_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(self.session, document, llm_permanent_message(e))
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, llm_permanent_message(e)
|
||||
)
|
||||
|
||||
except RecursionError as e:
|
||||
log_chunking_overflow(ctx, e)
|
||||
await rollback_and_persist_failure(self.session, document, PipelineMessages.CHUNKING_OVERFLOW)
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, PipelineMessages.CHUNKING_OVERFLOW
|
||||
)
|
||||
|
||||
except EMBEDDING_ERRORS as e:
|
||||
log_embedding_error(ctx, e)
|
||||
await rollback_and_persist_failure(self.session, document, embedding_message(e))
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, embedding_message(e)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log_unexpected_error(ctx, e)
|
||||
await rollback_and_persist_failure(self.session, document, safe_exception_message(e))
|
||||
await rollback_and_persist_failure(
|
||||
self.session, document, safe_exception_message(e)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await self.session.refresh(document)
|
||||
|
||||
return document
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@dataclass
|
||||
class PipelineLogContext:
|
||||
connector_id: int
|
||||
connector_id: int | None
|
||||
search_space_id: int
|
||||
unique_id: str # always available from ConnectorDocument
|
||||
doc_id: int | None = None # set once the DB row exists (index phase only)
|
||||
|
|
|
|||
|
|
@ -22,18 +22,6 @@ def test_valid_document_created_with_required_fields():
|
|||
assert doc.created_by_id == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_omitting_connector_id_raises():
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="## Content",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
created_by_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
|
||||
|
||||
def test_omitting_created_by_id_raises():
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
|
|
@ -92,19 +80,6 @@ def test_empty_created_by_id_raises():
|
|||
)
|
||||
|
||||
|
||||
def test_zero_connector_id_raises():
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
title="Task",
|
||||
source_markdown="## Content",
|
||||
unique_id="task-1",
|
||||
document_type=DocumentType.CLICKUP_CONNECTOR,
|
||||
search_space_id=1,
|
||||
connector_id=0,
|
||||
created_by_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
|
||||
|
||||
def test_zero_search_space_id_raises():
|
||||
with pytest.raises(ValidationError):
|
||||
ConnectorDocument(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue