"""ARQ background task for processing knowledge base documents.""" import os import tempfile from docling.chunking import HybridChunker from docling.document_converter import DocumentConverter from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer from loguru import logger from transformers import AutoTokenizer from api.db import db_client from api.db.models import KnowledgeBaseChunkModel from api.services.gen_ai import OpenAIEmbeddingService from api.services.storage import storage_fs # For tokenization/chunking TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2" async def process_knowledge_base_document( ctx, document_id: int, s3_key: str, organization_id: int, max_tokens: int = 128, ): """Process a knowledge base document: download, chunk, embed, and store. Args: ctx: ARQ context document_id: Database ID of the document s3_key: S3 key where the file is stored organization_id: Organization ID max_tokens: Maximum number of tokens per chunk (default: 128) """ logger.info( f"Starting knowledge base document processing for document_id={document_id}, " f"s3_key={s3_key}, organization_id={organization_id}" ) temp_file_path = None try: # Update status to processing await db_client.update_document_status(document_id, "processing") # Extract file extension from S3 key filename = s3_key.split("/")[-1] file_extension = ( os.path.splitext(filename)[1] or ".bin" ) # Default to .bin if no extension # Create temp file for download with correct extension temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) temp_file_path = temp_file.name temp_file.close() # Download file from S3 logger.info(f"Downloading file from S3: {s3_key}") download_success = await storage_fs.adownload_file(s3_key, temp_file_path) if not download_success: raise Exception(f"Failed to download file from S3: {s3_key}") if not os.path.exists(temp_file_path): raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}") file_size = os.path.getsize(temp_file_path) logger.info(f"Downloaded file size: {file_size} bytes") # Compute file hash and get mime type file_hash = db_client.compute_file_hash(temp_file_path) mime_type = db_client.get_mime_type(temp_file_path) filename = s3_key.split("/")[-1] # Get document record document = await db_client.get_document_by_id(document_id) if not document: raise Exception(f"Document {document_id} not found") # Check if a document with this hash already exists (reject duplicates) existing_doc = await db_client.get_document_by_hash(file_hash, organization_id) if existing_doc and existing_doc.id != document_id: error_message = ( f"This file is a duplicate of '{existing_doc.filename}'. " f"Please delete the duplicate files and consolidate them into a single unique file before uploading." ) logger.warning( f"Duplicate document detected: {document_id} is duplicate of {existing_doc.id} " f"({existing_doc.filename})" ) # Update file metadata await db_client.update_document_metadata( document_id, file_size_bytes=file_size, file_hash=file_hash, mime_type=mime_type, ) # Mark as failed with duplicate error message await db_client.update_document_status( document_id, "failed", error_message=error_message, docling_metadata={ "duplicate_of": existing_doc.document_uuid, "duplicate_filename": existing_doc.filename, }, ) return # Update document with file metadata await db_client.update_document_metadata( document_id, file_size_bytes=file_size, file_hash=file_hash, mime_type=mime_type, ) # Initialize the OpenAI embedding service logger.info( f"Initializing OpenAI embedding service with max_tokens={max_tokens}" ) # Try to get user's embeddings configuration embeddings_api_key = None embeddings_model = None embeddings_base_url = None if document.created_by: user_config = await db_client.get_user_configurations(document.created_by) if user_config.embeddings: embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model embeddings_base_url = getattr(user_config.embeddings, "base_url", None) logger.info(f"Using user embeddings config: model={embeddings_model}") # Check if API key is configured if not embeddings_api_key: error_message = ( "OpenAI API key not configured. Please set your API key in " "Model Configurations > Embedding to process documents." ) logger.warning(f"Document {document_id}: {error_message}") await db_client.update_document_status( document_id, "failed", error_message=error_message ) return service = OpenAIEmbeddingService( db_client=db_client, max_tokens=max_tokens, api_key=embeddings_api_key, model_id=embeddings_model or "text-embedding-3-small", base_url=embeddings_base_url, ) # Step 1: Convert document with docling logger.info("Converting document with docling") converter = DocumentConverter() conversion_result = converter.convert(temp_file_path) doc = conversion_result.document # Store docling metadata docling_metadata = { "num_pages": len(doc.pages) if hasattr(doc, "pages") else None, "document_type": type(doc).__name__, } # Step 2: Initialize tokenizer for chunking logger.info( f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}" ) tokenizer = HuggingFaceTokenizer( tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL), max_tokens=max_tokens, ) # Step 3: Initialize chunker logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}") chunker = HybridChunker(tokenizer=tokenizer) # Step 4: Chunk the document logger.info(f"Chunking document with max_tokens={max_tokens}") chunks = list(chunker.chunk(dl_doc=doc)) total_chunks = len(chunks) logger.info(f"Generated {total_chunks} chunks") # Step 5: Process each chunk chunk_texts = [] chunk_records = [] token_counts = [] for i, chunk in enumerate(chunks): chunk_text = chunk.text contextualized_text = chunker.contextualize(chunk=chunk) # Calculate token count text_to_tokenize = ( contextualized_text if contextualized_text else chunk_text ) token_count = len( tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False) ) token_counts.append(token_count) # Prepare chunk metadata chunk_metadata = {} if hasattr(chunk, "meta") and chunk.meta: chunk_metadata = { "doc_items": ( [str(item) for item in chunk.meta.doc_items] if hasattr(chunk.meta, "doc_items") else [] ), "headings": ( chunk.meta.headings if hasattr(chunk.meta, "headings") else [] ), } # Create chunk record (without embedding yet) chunk_record = KnowledgeBaseChunkModel( document_id=document_id, organization_id=organization_id, chunk_text=chunk_text, contextualized_text=contextualized_text, chunk_index=i, chunk_metadata=chunk_metadata, embedding_model=service.get_model_id(), embedding_dimension=service.get_embedding_dimension(), token_count=token_count, ) chunk_records.append(chunk_record) chunk_texts.append(text_to_tokenize) # Log chunk statistics if token_counts: avg_tokens = sum(token_counts) / len(token_counts) min_tokens = min(token_counts) max_tokens_actual = max(token_counts) logger.info("Chunk token statistics:") logger.info(f" - Average: {avg_tokens:.1f} tokens") logger.info(f" - Min: {min_tokens} tokens") logger.info(f" - Max: {max_tokens_actual} tokens") # Step 6: Generate embeddings using OpenAI logger.info(f"Generating embeddings using {service.get_model_id()}") embeddings = await service.embed_texts(chunk_texts) # Step 7: Attach embeddings to chunk records for chunk_record, embedding in zip(chunk_records, embeddings): chunk_record.embedding = embedding # Step 8: Save chunks in database logger.info("Storing chunks in database") await db_client.create_chunks_batch(chunk_records) # Step 9: Update document status to completed await db_client.update_document_status( document_id, "completed", total_chunks=total_chunks, docling_metadata=docling_metadata, ) logger.info( f"Successfully processed knowledge base document {document_id}. " f"Total chunks: {total_chunks}" ) except Exception as e: logger.error( f"Error processing knowledge base document {document_id}: {e}", exc_info=True, ) # Update document status to failed await db_client.update_document_status( document_id, "failed", error_message=str(e) ) raise finally: # Clean up temp file if temp_file_path and os.path.exists(temp_file_path): try: os.remove(temp_file_path) logger.debug(f"Cleaned up temp file: {temp_file_path}") except Exception as e: logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")