# File: dograh/api/tasks/knowledge_base_processing.py
# (Repository viewer metadata: 312 lines, 12 KiB, Python.)

"""ARQ background task for processing knowledge base documents."""
import os
import tempfile
from typing import Literal
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from transformers import AutoTokenizer
from api.db import db_client
from api.db.models import KnowledgeBaseChunkModel
from api.services.gen_ai import (
OpenAIEmbeddingService,
SentenceTransformerEmbeddingService,
)
from api.services.storage import storage_fs
# Hugging Face model whose tokenizer is used for chunking and for counting
# tokens per chunk. It is the SentenceTransformer tokenizer and serves as the
# baseline regardless of which embedding service is selected, so for other
# embedding models the counts are presumably only approximate — TODO confirm.
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
def _chunk_metadata(chunk) -> dict:
    """Return the metadata dict stored alongside a chunk.

    Returns an empty dict when the chunk has no (or a falsy) ``meta``.
    Otherwise returns the stringified ``doc_items`` and the ``headings`` of
    ``chunk.meta``, each falling back to an empty list when the attribute is
    absent.
    """
    meta = getattr(chunk, "meta", None)
    if not meta:
        return {}
    return {
        "doc_items": (
            [str(item) for item in meta.doc_items]
            if hasattr(meta, "doc_items")
            else []
        ),
        "headings": meta.headings if hasattr(meta, "headings") else [],
    }


async def process_knowledge_base_document(
    ctx,
    document_id: int,
    s3_key: str,
    organization_id: int,
    max_tokens: int = 128,
    embedding_service: Literal["sentence_transformer", "openai"] = "openai",
) -> None:
    """Process a knowledge base document: download, chunk, embed, and store.

    The document status is moved to "processing" at the start and to
    "completed" on success. Two conditions end the task early with status
    "failed" and *no* exception: the file's hash matches another document of
    the same organization, or the OpenAI service is selected but no
    embeddings API key is configured for the document's creator. Any other
    error marks the document "failed" and is re-raised. The temporary
    download is always removed.

    Args:
        ctx: ARQ context (not used by this function).
        document_id: Database ID of the document.
        s3_key: S3 key where the file is stored.
        organization_id: Organization ID.
        max_tokens: Maximum number of tokens per chunk (default: 128).
        embedding_service: Embedding service to use (default: "openai").
            - "openai": Use OpenAI text-embedding-3-small (1536-dim,
              requires API key) unless the user configured another model.
            - "sentence_transformer": Use SentenceTransformer
              (all-MiniLM-L6-v2, 384-dim, free).

    Raises:
        RuntimeError: If the download fails or the document record is missing.
        FileNotFoundError: If the download reported success but left no file.
        ValueError: If ``embedding_service`` is not a supported value.
        Exception: Anything raised by the storage, database, docling or
            embedding layers is re-raised after the status update.
    """
    logger.info(
        f"Starting knowledge base document processing for document_id={document_id}, "
        f"s3_key={s3_key}, organization_id={organization_id}"
    )
    temp_file_path = None
    try:
        # Update status to processing
        await db_client.update_document_status(document_id, "processing")

        # Keep the original extension on the temp file (".bin" when the key
        # has none) so the converter sees a correctly named file.
        filename = s3_key.split("/")[-1]
        file_extension = os.path.splitext(filename)[1] or ".bin"

        # Only reserve a path here; the download below writes the content.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=file_extension
        ) as temp_file:
            temp_file_path = temp_file.name

        # Download file from S3
        logger.info(f"Downloading file from S3: {s3_key}")
        download_success = await storage_fs.adownload_file(s3_key, temp_file_path)
        if not download_success:
            raise RuntimeError(f"Failed to download file from S3: {s3_key}")
        if not os.path.exists(temp_file_path):
            raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}")
        file_size = os.path.getsize(temp_file_path)
        logger.info(f"Downloaded file size: {file_size} bytes")

        # Compute file hash and get mime type
        file_hash = db_client.compute_file_hash(temp_file_path)
        mime_type = db_client.get_mime_type(temp_file_path)

        # Get document record
        document = await db_client.get_document_by_id(document_id)
        if not document:
            raise RuntimeError(f"Document {document_id} not found")

        # Check if a document with this hash already exists (reject duplicates)
        existing_doc = await db_client.get_document_by_hash(file_hash, organization_id)
        if existing_doc and existing_doc.id != document_id:
            error_message = (
                f"This file is a duplicate of '{existing_doc.filename}'. "
                "Please delete the duplicate files and consolidate them into a single unique file before uploading."
            )
            logger.warning(
                f"Duplicate document detected: {document_id} is duplicate of {existing_doc.id} "
                f"({existing_doc.filename})"
            )
            # Update file metadata
            await db_client.update_document_metadata(
                document_id,
                file_size_bytes=file_size,
                file_hash=file_hash,
                mime_type=mime_type,
            )
            # Mark as failed with duplicate error message
            await db_client.update_document_status(
                document_id,
                "failed",
                error_message=error_message,
                docling_metadata={
                    "duplicate_of": existing_doc.document_uuid,
                    "duplicate_filename": existing_doc.filename,
                },
            )
            return

        # Update document with file metadata
        await db_client.update_document_metadata(
            document_id,
            file_size_bytes=file_size,
            file_hash=file_hash,
            mime_type=mime_type,
        )

        # Initialize the embedding service based on the parameter
        if embedding_service == "openai":
            logger.info(
                f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
            )
            # Try to get user's embeddings configuration
            embeddings_api_key = None
            embeddings_model = None
            if document.created_by:
                user_config = await db_client.get_user_configurations(
                    document.created_by
                )
                # Guard against a missing configuration record so the user
                # gets the actionable "API key not configured" message below
                # instead of an AttributeError.
                if user_config and user_config.embeddings:
                    embeddings_api_key = user_config.embeddings.api_key
                    embeddings_model = user_config.embeddings.model
                    logger.info(
                        f"Using user embeddings config: model={embeddings_model}"
                    )
            # Check if API key is configured
            if not embeddings_api_key:
                error_message = (
                    "OpenAI API key not configured. Please set your API key in "
                    "Model Configurations > Embedding to process documents."
                )
                logger.warning(f"Document {document_id}: {error_message}")
                await db_client.update_document_status(
                    document_id, "failed", error_message=error_message
                )
                return
            service = OpenAIEmbeddingService(
                db_client=db_client,
                max_tokens=max_tokens,
                api_key=embeddings_api_key,
                model_id=embeddings_model or "text-embedding-3-small",
            )
        elif embedding_service == "sentence_transformer":
            logger.info(
                f"Initializing SentenceTransformer embedding service with max_tokens={max_tokens}"
            )
            service = SentenceTransformerEmbeddingService(
                db_client=db_client,
                max_tokens=max_tokens,
            )
        else:
            raise ValueError(
                f"Invalid embedding_service: {embedding_service}. "
                f"Must be 'sentence_transformer' or 'openai'"
            )

        # Step 1: Convert document with docling
        # NOTE(review): convert() is a synchronous call inside a coroutine and
        # presumably blocks the worker's event loop while it runs — confirm
        # whether it should be offloaded to a thread.
        logger.info("Converting document with docling")
        converter = DocumentConverter()
        conversion_result = converter.convert(temp_file_path)
        doc = conversion_result.document

        # Store docling metadata
        docling_metadata = {
            "num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
            "document_type": type(doc).__name__,
        }

        # Step 2: Initialize tokenizer for chunking
        logger.info(
            f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
        )
        tokenizer = HuggingFaceTokenizer(
            tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
            max_tokens=max_tokens,
        )

        # Step 3: Initialize chunker
        logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
        chunker = HybridChunker(tokenizer=tokenizer)

        # Step 4: Chunk the document
        logger.info(f"Chunking document with max_tokens={max_tokens}")
        chunks = list(chunker.chunk(dl_doc=doc))
        total_chunks = len(chunks)
        logger.info(f"Generated {total_chunks} chunks")

        # Step 5: Process each chunk
        # The model id and dimension do not vary per chunk; look them up once.
        embedding_model_id = service.get_model_id()
        embedding_dimension = service.get_embedding_dimension()
        chunk_texts = []
        chunk_records = []
        token_counts = []
        for index, chunk in enumerate(chunks):
            chunk_text = chunk.text
            contextualized_text = chunker.contextualize(chunk=chunk)
            # Embed (and count tokens on) the contextualized text when the
            # chunker produced one, otherwise the raw chunk text.
            text_to_tokenize = contextualized_text or chunk_text
            token_count = len(
                tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
            )
            token_counts.append(token_count)

            # Create chunk record (without embedding yet)
            chunk_records.append(
                KnowledgeBaseChunkModel(
                    document_id=document_id,
                    organization_id=organization_id,
                    chunk_text=chunk_text,
                    contextualized_text=contextualized_text,
                    chunk_index=index,
                    chunk_metadata=_chunk_metadata(chunk),
                    embedding_model=embedding_model_id,
                    embedding_dimension=embedding_dimension,
                    token_count=token_count,
                )
            )
            chunk_texts.append(text_to_tokenize)

        # Log chunk statistics
        if token_counts:
            avg_tokens = sum(token_counts) / len(token_counts)
            min_tokens = min(token_counts)
            max_tokens_actual = max(token_counts)
            logger.info("Chunk token statistics:")
            logger.info(f" - Average: {avg_tokens:.1f} tokens")
            logger.info(f" - Min: {min_tokens} tokens")
            logger.info(f" - Max: {max_tokens_actual} tokens")

        # Step 6: Generate embeddings using the embedding service
        logger.info(f"Generating embeddings using {embedding_service}")
        embeddings = await service.embed_texts(chunk_texts)

        # Step 7: Attach embeddings to chunk records
        for chunk_record, embedding in zip(chunk_records, embeddings):
            chunk_record.embedding = embedding

        # Step 8: Save chunks in database
        logger.info("Storing chunks in database")
        await db_client.create_chunks_batch(chunk_records)

        # Step 9: Update document status to completed
        await db_client.update_document_status(
            document_id,
            "completed",
            total_chunks=total_chunks,
            docling_metadata=docling_metadata,
        )
        logger.info(
            f"Successfully processed knowledge base document {document_id}. "
            f"Total chunks: {total_chunks}"
        )
    except Exception as e:
        # loguru has no ``exc_info`` keyword: passing it logs no traceback and
        # makes loguru str.format() the message, which itself raises when the
        # exception text contains braces. logger.exception() attaches the
        # traceback and, with no extra arguments, leaves the message as-is.
        logger.exception(
            f"Error processing knowledge base document {document_id}: {e}"
        )
        # Update document status to failed. A failure of this update must not
        # replace the original error, so it is logged and swallowed here.
        try:
            await db_client.update_document_status(
                document_id, "failed", error_message=str(e)
            )
        except Exception as status_error:
            logger.error(
                f"Could not mark document {document_id} as failed: {status_error}"
            )
        raise
    finally:
        # Clean up temp file
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
                logger.debug(f"Cleaned up temp file: {temp_file_path}")
            except OSError as e:
                logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")