"""ARQ background task for processing knowledge base documents. Document conversion and chunking live in the Model Proxy Service (MPS); this task downloads the file from S3, calls MPS, then handles the embedding and DB writes locally. """ import os import tempfile from loguru import logger from api.db import db_client from api.db.models import KnowledgeBaseChunkModel from api.services.configuration.registry import ServiceProviders from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService from api.services.mps_service_key_client import mps_service_key_client from api.services.storage import storage_fs MAX_FILE_SIZE_BYTES = 5 * 1024 * 1024 async def process_knowledge_base_document( ctx, document_id: int, s3_key: str, organization_id: int, created_by_provider_id: str, max_tokens: int = 128, retrieval_mode: str = "chunked", ): """Process a knowledge base document via MPS: download, call MPS, embed, store. Args: ctx: ARQ context document_id: Database ID of the document s3_key: S3 key where the file is stored organization_id: Organization ID created_by_provider_id: Uploading user's provider ID (for OSS-mode auth to MPS) max_tokens: Maximum number of tokens per chunk (default: 128) retrieval_mode: "chunked" for vector search or "full_document" for full text """ logger.info( f"Processing knowledge base document: document_id={document_id}, " f"s3_key={s3_key}, org={organization_id}, mode={retrieval_mode}" ) temp_file_path = None try: await db_client.update_document_status(document_id, "processing") filename = s3_key.split("/")[-1] file_extension = os.path.splitext(filename)[1] or ".bin" temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) temp_file_path = temp_file.name temp_file.close() logger.info(f"Downloading file from S3: {s3_key}") download_success = await storage_fs.adownload_file(s3_key, temp_file_path) if not download_success: raise Exception(f"Failed to download file from S3: {s3_key}") if not os.path.exists(temp_file_path): raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}") file_size = os.path.getsize(temp_file_path) logger.info(f"Downloaded file size: {file_size} bytes") if file_size > MAX_FILE_SIZE_BYTES: error_message = ( f"File size ({file_size / (1024 * 1024):.1f}MB) exceeds the " f"maximum allowed size of {MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB." ) logger.warning(f"Document {document_id}: {error_message}") await db_client.update_document_status( document_id, "failed", error_message=error_message ) return file_hash = db_client.compute_file_hash(temp_file_path) mime_type = db_client.get_mime_type(temp_file_path) document = await db_client.get_document_by_id(document_id) if not document: raise Exception(f"Document {document_id} not found") # Reject duplicates (same hash already ingested for this org). existing_doc = await db_client.get_document_by_hash(file_hash, organization_id) if existing_doc and existing_doc.id != document_id: error_message = ( f"This file is a duplicate of '{existing_doc.filename}'. " f"Please delete the duplicate files and consolidate them into a " f"single unique file before uploading." ) logger.warning( f"Duplicate document detected: {document_id} is duplicate of " f"{existing_doc.id} ({existing_doc.filename})" ) await db_client.update_document_metadata( document_id, file_size_bytes=file_size, file_hash=file_hash, mime_type=mime_type, ) await db_client.update_document_status( document_id, "failed", error_message=error_message, docling_metadata={ "duplicate_of": existing_doc.document_uuid, "duplicate_filename": existing_doc.filename, }, ) return await db_client.update_document_metadata( document_id, file_size_bytes=file_size, file_hash=file_hash, mime_type=mime_type, ) logger.info(f"Delegating document processing to MPS (mode={retrieval_mode})") mps_response = await mps_service_key_client.process_document( file_path=temp_file_path, filename=filename, content_type=mime_type or "application/octet-stream", retrieval_mode=retrieval_mode, max_tokens=max_tokens, organization_id=organization_id, created_by=created_by_provider_id, ) docling_metadata = mps_response.get("docling_metadata", {}) if retrieval_mode == "full_document": full_text = mps_response.get("full_text") or "" await db_client.update_document_full_text(document_id, full_text) await db_client.update_document_status( document_id, "completed", total_chunks=0, docling_metadata=docling_metadata, ) logger.info( f"Successfully processed full_document {document_id}. " f"Text length: {len(full_text)} chars" ) return # Chunked mode: fetch user embedding config, embed, and persist chunks. embeddings_provider = None embeddings_api_key = None embeddings_model = None embeddings_base_url = None embeddings_endpoint = None embeddings_api_version = None if document.created_by: user_config = await db_client.get_user_configurations(document.created_by) if user_config.embeddings: embeddings_provider = getattr(user_config.embeddings, "provider", None) embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model embeddings_base_url = getattr(user_config.embeddings, "base_url", None) embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) embeddings_api_version = getattr(user_config.embeddings, "api_version", None) logger.info( f"Using user embeddings config: provider={embeddings_provider}, " f"model={embeddings_model}" ) if not embeddings_api_key: error_message = ( "API key not configured. Please set your API key in " "Model Configurations > Embedding to process documents." ) logger.warning(f"Document {document_id}: {error_message}") await db_client.update_document_status( document_id, "failed", error_message=error_message ) return if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint: embedding_service = AzureOpenAIEmbeddingService( db_client=db_client, api_key=embeddings_api_key, endpoint=embeddings_endpoint, model_id=embeddings_model or "text-embedding-3-small", api_version=embeddings_api_version or "2024-02-15-preview", ) else: embedding_service = OpenAIEmbeddingService( db_client=db_client, api_key=embeddings_api_key, model_id=embeddings_model or "text-embedding-3-small", base_url=embeddings_base_url, ) mps_chunks = mps_response.get("chunks", []) if not mps_chunks: logger.warning(f"Document {document_id}: MPS returned zero chunks") chunk_records = [] chunk_texts = [] for chunk in mps_chunks: contextualized = chunk.get("contextualized_text") or chunk["chunk_text"] chunk_records.append( KnowledgeBaseChunkModel( document_id=document_id, organization_id=organization_id, chunk_text=chunk["chunk_text"], contextualized_text=contextualized, chunk_index=chunk["chunk_index"], chunk_metadata=chunk.get("chunk_metadata") or {}, embedding_model=embedding_service.get_model_id(), embedding_dimension=embedding_service.get_embedding_dimension(), token_count=chunk.get("token_count", 0), ) ) chunk_texts.append(contextualized) logger.info( f"Generating embeddings for {len(chunk_texts)} chunks " f"using {embedding_service.get_model_id()}" ) embeddings = await embedding_service.embed_texts(chunk_texts) for chunk_record, embedding in zip(chunk_records, embeddings): chunk_record.embedding = embedding logger.info("Storing chunks in database") await db_client.create_chunks_batch(chunk_records) await db_client.update_document_status( document_id, "completed", total_chunks=len(chunk_records), docling_metadata=docling_metadata, ) logger.info( f"Successfully processed knowledge base document {document_id}. " f"Total chunks: {len(chunk_records)}" ) except Exception as e: logger.error( f"Error processing knowledge base document {document_id}: {e}", exc_info=True, ) await db_client.update_document_status( document_id, "failed", error_message=str(e) ) raise finally: if temp_file_path and os.path.exists(temp_file_path): try: os.remove(temp_file_path) logger.debug(f"Cleaned up temp file: {temp_file_path}") except Exception as e: logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")