mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-30 21:59:46 +02:00
feat: simplified document upload handling
- Introduced a new endpoint for batch document status retrieval, allowing users to check the status of multiple documents in a search space. - Enhanced the document upload process to return duplicate document IDs and improved response structure. - Updated schemas to include new response models for document status. - Removed unused attachment processing code from chat routes and UI components to streamline functionality.
This commit is contained in:
parent
d11e76aaa1
commit
c979609041
15 changed files with 475 additions and 1090 deletions
|
|
@ -18,6 +18,8 @@ from app.db import (
|
|||
)
|
||||
from app.schemas import (
|
||||
DocumentRead,
|
||||
DocumentStatusBatchResponse,
|
||||
DocumentStatusItemRead,
|
||||
DocumentsCreate,
|
||||
DocumentStatusSchema,
|
||||
DocumentTitleRead,
|
||||
|
|
@ -148,6 +150,7 @@ async def create_documents_file_upload(
|
|||
tuple[Document, str, str]
|
||||
] = [] # (document, temp_path, filename)
|
||||
skipped_duplicates = 0
|
||||
duplicate_document_ids: list[int] = []
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
|
|
@ -182,6 +185,7 @@ async def create_documents_file_upload(
|
|||
# True duplicate — content already indexed, skip
|
||||
os.unlink(temp_path)
|
||||
skipped_duplicates += 1
|
||||
duplicate_document_ids.append(existing.id)
|
||||
continue
|
||||
|
||||
# Existing document is stuck (failed/pending/processing)
|
||||
|
|
@ -255,6 +259,7 @@ async def create_documents_file_upload(
|
|||
return {
|
||||
"message": "Files uploaded for processing",
|
||||
"document_ids": [doc.id for doc in created_documents],
|
||||
"duplicate_document_ids": duplicate_document_ids,
|
||||
"total_files": len(files),
|
||||
"pending_files": len(files_to_process),
|
||||
"skipped_duplicates": skipped_duplicates,
|
||||
|
|
@ -678,6 +683,74 @@ async def search_document_titles(
|
|||
) from e
|
||||
|
||||
|
||||
def _parse_document_ids(document_ids: str) -> list[int]:
    """Parse a comma-separated ID string into unique ints, first-seen order.

    Blank segments (e.g. from ``"1,,2"`` or a trailing comma) are ignored.

    Raises:
        HTTPException: 400 if a non-blank segment is not a valid integer.
    """
    # A dict is used as an ordered set so repeated IDs in the query string
    # do not produce redundant bind parameters in the IN (...) clause.
    unique_ids: dict[int, None] = {}
    for raw_id in document_ids.split(","):
        value = raw_id.strip()
        if not value:
            continue
        try:
            unique_ids[int(value)] = None
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid document id: {value}",
            ) from None
    return list(unique_ids)


@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
async def get_documents_status(
    search_space_id: int,
    document_ids: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Batch status endpoint for documents in a search space.

    Returns lightweight status info for the provided document IDs, intended for
    polling async ETL progress in chat upload flows.

    Args:
        search_space_id: Search space the documents must belong to.
        document_ids: Comma-separated document IDs (e.g. "1,2,3").
        session: Database session.
        user: Authenticated user; must hold DOCUMENTS_READ on the search space.

    Returns:
        DocumentStatusBatchResponse with one item per matching document. IDs
        that match no document in this search space are simply absent from
        the result.

    Raises:
        HTTPException: 400 for a malformed ID, 500 for unexpected failures;
            HTTPExceptions raised by the permission check propagate as-is.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.DOCUMENTS_READ.value,
            "You don't have permission to read documents in this search space",
        )

        parsed_ids = _parse_document_ids(document_ids)
        if not parsed_ids:
            # Nothing to look up; skip the query entirely.
            return DocumentStatusBatchResponse(items=[])

        result = await session.execute(
            select(Document).filter(
                Document.search_space_id == search_space_id,
                Document.id.in_(parsed_ids),
            )
        )
        docs = result.scalars().all()

        items = []
        for doc in docs:
            # A missing/empty status falls back to state "ready" with no reason.
            status_data = doc.status or {}
            items.append(
                DocumentStatusItemRead(
                    id=doc.id,
                    title=doc.title,
                    document_type=doc.document_type,
                    status=DocumentStatusSchema(
                        state=status_data.get("state", "ready"),
                        reason=status_data.get("reason"),
                    ),
                )
            )
        return DocumentStatusBatchResponse(items=items)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch document status: {e!s}"
        ) from e
|
||||
|
||||
|
||||
@router.get("/documents/type-counts")
|
||||
async def get_document_type_counts(
|
||||
search_space_id: int | None = None,
|
||||
|
|
|
|||
|
|
@ -8,16 +8,11 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
|||
- PUT /threads/{thread_id} - Update thread (rename, archive)
|
||||
- DELETE /threads/{thread_id} - Delete thread
|
||||
- POST /threads/{thread_id}/messages - Append message
|
||||
- POST /attachments/process - Process attachments for chat context
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
|
@ -1047,7 +1042,6 @@ async def handle_new_chat(
|
|||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
attachments=request.attachments,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
|
|
@ -1278,7 +1272,6 @@ async def regenerate_response(
|
|||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
attachments=request.attachments,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
|
|
@ -1334,184 +1327,3 @@ async def regenerate_response(
|
|||
detail=f"An unexpected error occurred during regeneration: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Attachment Processing Endpoint
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/attachments/process")
async def process_attachment(
    file: UploadFile = File(...),
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Process an attachment file and extract its content as markdown.

    This endpoint uses the configured ETL service to parse files and return
    the extracted content that can be used as context in chat messages.

    Supported file types depend on the configured ETL_SERVICE:
    - Markdown/Text files: .md, .markdown, .txt (always supported)
    - Audio files: .mp3, .mp4, .mpeg, .mpga, .m4a, .wav, .webm (if STT configured)
    - Documents: .pdf, .docx, .doc, .pptx, .xlsx (depends on ETL service)

    Returns:
        JSON with attachment id, name, type, and extracted content

    Raises:
        HTTPException: 400 if the upload has no filename; 422 if audio
            transcription or the ETL service is not configured, or if no
            content could be extracted; 500 for any other failure.
    """
    from app.config import config as app_config

    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    filename = file.filename
    attachment_id = str(uuid.uuid4())
    # Stays None until the temp file exists, so cleanup knows whether to run.
    temp_path: str | None = None

    try:
        # Save file to a temporary location
        file_ext = os.path.splitext(filename)[1].lower()
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
            temp_path = temp_file.name
            content = await file.read()
            temp_file.write(content)

        extracted_content = ""

        # Process based on file type
        if file_ext in (".md", ".markdown", ".txt"):
            # For text/markdown files, read content directly
            with open(temp_path, encoding="utf-8") as f:
                extracted_content = f.read()

        elif file_ext in (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"):
            # Audio files - transcribe if STT service is configured
            if not app_config.STT_SERVICE:
                raise HTTPException(
                    status_code=422,
                    detail="Audio transcription is not configured. Please set STT_SERVICE.",
                )

            if app_config.STT_SERVICE.startswith("local/"):
                from app.services.stt_service import stt_service

                # NOTE(review): called without await inside an async handler;
                # if this blocks it will stall the event loop - confirm.
                result = stt_service.transcribe_file(temp_path)
                extracted_content = result.get("text", "")
            else:
                from litellm import atranscription

                with open(temp_path, "rb") as audio_file:
                    transcription_kwargs = {
                        "model": app_config.STT_SERVICE,
                        "file": audio_file,
                        "api_key": app_config.STT_SERVICE_API_KEY,
                    }
                    if app_config.STT_SERVICE_API_BASE:
                        transcription_kwargs["api_base"] = (
                            app_config.STT_SERVICE_API_BASE
                        )

                    transcription_response = await atranscription(
                        **transcription_kwargs
                    )
                    extracted_content = transcription_response.get("text", "")

            if extracted_content:
                extracted_content = (
                    f"# Transcription of {filename}\n\n{extracted_content}"
                )

        else:
            # Document files - use configured ETL service
            if app_config.ETL_SERVICE == "UNSTRUCTURED":
                from langchain_unstructured import UnstructuredLoader

                from app.utils.document_converters import convert_document_to_markdown

                loader = UnstructuredLoader(
                    temp_path,
                    mode="elements",
                    post_processors=[],
                    languages=["eng"],
                    include_orig_elements=False,
                    include_metadata=False,
                    strategy="auto",
                )
                docs = await loader.aload()
                extracted_content = await convert_document_to_markdown(docs)

            elif app_config.ETL_SERVICE == "LLAMACLOUD":
                from llama_cloud_services import LlamaParse
                from llama_cloud_services.parse.utils import ResultType

                parser = LlamaParse(
                    api_key=app_config.LLAMA_CLOUD_API_KEY,
                    num_workers=1,
                    verbose=False,
                    language="en",
                    result_type=ResultType.MD,
                )
                result = await parser.aparse(temp_path)
                markdown_documents = await result.aget_markdown_documents(
                    split_by_page=False
                )

                if markdown_documents:
                    extracted_content = "\n\n".join(
                        doc.text for doc in markdown_documents
                    )

            elif app_config.ETL_SERVICE == "DOCLING":
                from app.services.docling_service import create_docling_service

                docling_service = create_docling_service()
                result = await docling_service.process_document(temp_path, filename)
                extracted_content = result.get("content", "")

            else:
                raise HTTPException(
                    status_code=422,
                    detail=f"ETL service not configured or unsupported file type: {file_ext}",
                )

        if not extracted_content:
            raise HTTPException(
                status_code=422,
                detail=f"Could not extract content from file: {filename}",
            )

        # Determine attachment type (must be one of: "image", "document", "file")
        # assistant-ui only supports these three types
        if file_ext in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
            attachment_type = "image"
        else:
            # All other files (including audio, documents, text) are treated as "document"
            attachment_type = "document"

        return {
            "id": attachment_id,
            "name": filename,
            "type": attachment_type,
            "content": extracted_content,
            "contentLength": len(extracted_content),
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process attachment: {e!s}",
        ) from e
    finally:
        # Single cleanup point: runs on success, on unexpected errors, and on
        # the HTTPException paths raised after the temp file was created,
        # which would otherwise leave the file behind.
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.unlink(temp_path)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
|||
from .documents import (
|
||||
DocumentBase,
|
||||
DocumentRead,
|
||||
DocumentStatusBatchResponse,
|
||||
DocumentStatusItemRead,
|
||||
DocumentsCreate,
|
||||
DocumentStatusSchema,
|
||||
DocumentTitleRead,
|
||||
|
|
@ -105,6 +107,8 @@ __all__ = [
|
|||
# Document schemas
|
||||
"DocumentBase",
|
||||
"DocumentRead",
|
||||
"DocumentStatusBatchResponse",
|
||||
"DocumentStatusItemRead",
|
||||
"DocumentStatusSchema",
|
||||
"DocumentTitleRead",
|
||||
"DocumentTitleSearchResponse",
|
||||
|
|
|
|||
|
|
@ -99,3 +99,20 @@ class DocumentTitleSearchResponse(BaseModel):
|
|||
|
||||
items: list[DocumentTitleRead]
|
||||
has_more: bool
|
||||
|
||||
|
||||
class DocumentStatusItemRead(BaseModel):
    """Single-document entry of a batch status response.

    Holds only the identifying fields and the current status, keeping the
    payload small for repeated polling.
    """

    # Permit building the model from attribute-bearing objects, not only dicts.
    model_config = ConfigDict(from_attributes=True)

    id: int
    title: str
    document_type: DocumentType
    status: DocumentStatusSchema
|
||||
|
||||
|
||||
class DocumentStatusBatchResponse(BaseModel):
    """Envelope returned when the status of several documents is requested."""

    # One entry per document found for the requested IDs.
    items: list[DocumentStatusItemRead]
|
||||
|
|
|
|||
|
|
@ -159,15 +159,6 @@ class ChatMessage(BaseModel):
|
|||
content: str
|
||||
|
||||
|
||||
class ChatAttachment(BaseModel):
    """A chat attachment together with the content extracted from it."""

    # Unique attachment ID
    id: str
    # Original filename
    name: str
    # Attachment type: document, image, audio
    type: str
    # Extracted markdown content from the file
    content: str
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
"""Request schema for the deep agent chat endpoint."""
|
||||
|
||||
|
|
@ -175,9 +166,6 @@ class NewChatRequest(BaseModel):
|
|||
user_query: str
|
||||
search_space_id: int
|
||||
messages: list[ChatMessage] | None = None # Optional chat history from frontend
|
||||
attachments: list[ChatAttachment] | None = (
|
||||
None # Optional attachments with extracted content
|
||||
)
|
||||
mentioned_document_ids: list[int] | None = (
|
||||
None # Optional document IDs mentioned with @ in the chat
|
||||
)
|
||||
|
|
@ -201,7 +189,6 @@ class RegenerateRequest(BaseModel):
|
|||
user_query: str | None = (
|
||||
None # New user query (for edit). None = reload with same query
|
||||
)
|
||||
attachments: list[ChatAttachment] | None = None
|
||||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from app.agents.new_chat.llm_config import (
|
|||
)
|
||||
from app.db import ChatVisibility, Document, SurfsenseDocsDocument
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.schemas.new_chat import ChatAttachment
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -37,24 +36,6 @@ from app.services.connector_service import ConnectorService
|
|||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
|
||||
|
||||
def format_attachments_as_context(attachments: list[ChatAttachment]) -> str:
    """Format attachments as an XML-like context block for the agent.

    Each attachment becomes an ``<attachment>`` element (1-based ``index``,
    ``name`` and ``type`` attributes) wrapping its content in a CDATA
    section, all enclosed in ``<user_attachments>``.

    Args:
        attachments: Attachments to render; each must expose ``name``,
            ``type`` and ``content`` string attributes.

    Returns:
        The newline-joined context block, or "" when there are no attachments.
    """
    import html

    if not attachments:
        return ""

    context_parts = ["<user_attachments>"]
    for i, attachment in enumerate(attachments, 1):
        # Name and type come from user-supplied files; escape them so a quote
        # or angle bracket cannot break out of the single-quoted attribute.
        safe_name = html.escape(attachment.name, quote=True)
        safe_type = html.escape(attachment.type, quote=True)
        # A literal "]]>" in the content would end the CDATA section early;
        # split it across two adjacent CDATA sections instead.
        safe_content = attachment.content.replace("]]>", "]]]]><![CDATA[>")

        context_parts.append(
            f"<attachment index='{i}' name='{safe_name}' type='{safe_type}'>"
        )
        context_parts.append(f"<![CDATA[{safe_content}]]>")
        context_parts.append("</attachment>")
    context_parts.append("</user_attachments>")

    return "\n".join(context_parts)
|
||||
|
||||
|
||||
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
|
||||
"""
|
||||
Format mentioned documents as context for the agent.
|
||||
|
|
@ -203,7 +184,6 @@ async def stream_new_chat(
|
|||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
attachments: list[ChatAttachment] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
|
|
@ -224,7 +204,6 @@ async def stream_new_chat(
|
|||
session: The database session
|
||||
user_id: The current user's UUID string (for memory tools and session state)
|
||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||
attachments: Optional attachments with extracted content
|
||||
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
||||
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
|
||||
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
|
||||
|
|
@ -360,13 +339,10 @@ async def stream_new_chat(
|
|||
)
|
||||
mentioned_surfsense_docs = list(result.scalars().all())
|
||||
|
||||
# Format the user query with context (attachments + mentioned documents + surfsense docs)
|
||||
# Format the user query with context (mentioned documents + SurfSense docs)
|
||||
final_query = user_query
|
||||
context_parts = []
|
||||
|
||||
if attachments:
|
||||
context_parts.append(format_attachments_as_context(attachments))
|
||||
|
||||
if mentioned_documents:
|
||||
context_parts.append(
|
||||
format_mentioned_documents_as_context(mentioned_documents)
|
||||
|
|
@ -459,39 +435,20 @@ async def stream_new_chat(
|
|||
last_active_step_id = analyze_step_id
|
||||
|
||||
# Determine step title and action verb based on context
|
||||
if attachments and (mentioned_documents or mentioned_surfsense_docs):
|
||||
last_active_step_title = "Analyzing your content"
|
||||
action_verb = "Reading"
|
||||
elif attachments:
|
||||
last_active_step_title = "Reading your content"
|
||||
action_verb = "Reading"
|
||||
elif mentioned_documents or mentioned_surfsense_docs:
|
||||
if mentioned_documents or mentioned_surfsense_docs:
|
||||
last_active_step_title = "Analyzing referenced content"
|
||||
action_verb = "Analyzing"
|
||||
else:
|
||||
last_active_step_title = "Understanding your request"
|
||||
action_verb = "Processing"
|
||||
|
||||
# Build the message with inline context about attachments/documents
|
||||
# Build the message with inline context about referenced documents
|
||||
processing_parts = []
|
||||
|
||||
# Add the user query
|
||||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
|
||||
# Add file attachment names inline
|
||||
if attachments:
|
||||
attachment_names = []
|
||||
for attachment in attachments:
|
||||
name = attachment.name
|
||||
if len(name) > 30:
|
||||
name = name[:27] + "..."
|
||||
attachment_names.append(name)
|
||||
if len(attachment_names) == 1:
|
||||
processing_parts.append(f"[{attachment_names[0]}]")
|
||||
else:
|
||||
processing_parts.append(f"[{len(attachment_names)} files]")
|
||||
|
||||
# Add mentioned document names inline
|
||||
if mentioned_documents:
|
||||
doc_names = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue