dograh/api/tasks/knowledge_base_processing.py
Abhishek 00a1a22b74
feat: refactor node spec and add mcp tools (#244)
* refactor: carve out extraction panel

* refactor: create spec versions for node types

* refactor: create a GenericNode and remove custom nodes

* feat: add python and typescript sdk

* add dograh sdk

* fix: fetch draft workflow definition over published one

* fix: fix routes of SDKs to use code gen

* chore: remove docling dependency to reduce image size

* chore: format files

* chore: bump pipecat

* feat: let mcp fetch archived workflows on demand

* chore: fix tests

* feat: add sdk documentation

* chore: change banner and add badge
2026-04-21 07:56:16 +05:30

243 lines
9.3 KiB
Python

"""ARQ background task for processing knowledge base documents.
Document conversion and chunking live in the Model Proxy Service (MPS);
this task downloads the file from S3, calls MPS, then handles the embedding
and DB writes locally.
"""
import os
import tempfile
from loguru import logger
from api.db import db_client
from api.db.models import KnowledgeBaseChunkModel
from api.services.gen_ai import OpenAIEmbeddingService
from api.services.mps_service_key_client import mps_service_key_client
from api.services.storage import storage_fs
# Largest downloaded file (in bytes, 5 MiB) this task will process; anything
# bigger is marked "failed" before the document is sent to MPS.
MAX_FILE_SIZE_BYTES = 5 * 1024 * 1024
def _build_chunk_records(mps_chunks, document_id, organization_id, embedding_service):
    """Convert MPS chunk payloads into unsaved chunk rows and their embed texts.

    Args:
        mps_chunks: Chunk dicts from the MPS response. Each one must carry
            "chunk_text" and "chunk_index"; "contextualized_text",
            "chunk_metadata" and "token_count" are optional.
        document_id: Database ID of the owning document.
        organization_id: Organization ID stamped on every chunk row.
        embedding_service: Service queried for the model id and dimension
            recorded on each row.

    Returns:
        A ``(records, texts)`` pair of equal length where ``texts[i]`` is the
        text that must be embedded for ``records[i]``.

    Raises:
        KeyError: If a chunk lacks "chunk_text" or "chunk_index".
    """
    # These do not vary per chunk, so look them up once instead of per row.
    model_id = embedding_service.get_model_id()
    dimension = embedding_service.get_embedding_dimension()

    records = []
    texts = []
    for chunk in mps_chunks:
        # Fall back to the raw chunk text when no contextualized variant exists.
        contextualized = chunk.get("contextualized_text") or chunk["chunk_text"]
        records.append(
            KnowledgeBaseChunkModel(
                document_id=document_id,
                organization_id=organization_id,
                chunk_text=chunk["chunk_text"],
                contextualized_text=contextualized,
                chunk_index=chunk["chunk_index"],
                chunk_metadata=chunk.get("chunk_metadata") or {},
                embedding_model=model_id,
                embedding_dimension=dimension,
                token_count=chunk.get("token_count", 0),
            )
        )
        texts.append(contextualized)
    return records, texts


async def process_knowledge_base_document(
    ctx,
    document_id: int,
    s3_key: str,
    organization_id: int,
    created_by_provider_id: str,
    max_tokens: int = 128,
    retrieval_mode: str = "chunked",
) -> None:
    """Process a knowledge base document via MPS: download, call MPS, embed, store.

    The document status is moved to "processing" up front and ends as either
    "completed" or "failed". Expected rejections (oversized file, duplicate
    file, missing embeddings API key) mark the document "failed" and return
    normally; unexpected errors mark it "failed" and are re-raised.

    Args:
        ctx: ARQ context
        document_id: Database ID of the document
        s3_key: S3 key where the file is stored
        organization_id: Organization ID
        created_by_provider_id: Uploading user's provider ID (for OSS-mode auth to MPS)
        max_tokens: Maximum number of tokens per chunk (default: 128)
        retrieval_mode: "chunked" for vector search or "full_document" for full text

    Raises:
        Exception: Any unexpected failure (download, MPS, embedding, DB). The
            document is marked "failed" on a best-effort basis before the
            original exception propagates.
    """
    logger.info(
        f"Processing knowledge base document: document_id={document_id}, "
        f"s3_key={s3_key}, org={organization_id}, mode={retrieval_mode}"
    )
    temp_file_path = None
    try:
        await db_client.update_document_status(document_id, "processing")

        # --- Download the file to a local temp path -------------------------
        filename = s3_key.split("/")[-1]
        file_extension = os.path.splitext(filename)[1] or ".bin"
        # delete=False: only the path is needed here; the handle is closed
        # right away and the file is removed in the ``finally`` block.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=file_extension
        ) as temp_file:
            temp_file_path = temp_file.name

        logger.info(f"Downloading file from S3: {s3_key}")
        download_success = await storage_fs.adownload_file(s3_key, temp_file_path)
        if not download_success:
            raise Exception(f"Failed to download file from S3: {s3_key}")
        if not os.path.exists(temp_file_path):
            raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}")

        file_size = os.path.getsize(temp_file_path)
        logger.info(f"Downloaded file size: {file_size} bytes")
        if file_size > MAX_FILE_SIZE_BYTES:
            error_message = (
                f"File size ({file_size / (1024 * 1024):.1f}MB) exceeds the "
                f"maximum allowed size of {MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB."
            )
            logger.warning(f"Document {document_id}: {error_message}")
            await db_client.update_document_status(
                document_id, "failed", error_message=error_message
            )
            return

        file_hash = db_client.compute_file_hash(temp_file_path)
        mime_type = db_client.get_mime_type(temp_file_path)

        document = await db_client.get_document_by_id(document_id)
        if not document:
            raise Exception(f"Document {document_id} not found")

        # --- Duplicate detection --------------------------------------------
        # Look up the hash BEFORE writing it onto this document, so the lookup
        # runs against the other documents already ingested for this org.
        existing_doc = await db_client.get_document_by_hash(file_hash, organization_id)

        # Both the duplicate and the normal path record the file metadata.
        await db_client.update_document_metadata(
            document_id,
            file_size_bytes=file_size,
            file_hash=file_hash,
            mime_type=mime_type,
        )

        # Reject duplicates (same hash already ingested for this org).
        if existing_doc and existing_doc.id != document_id:
            error_message = (
                f"This file is a duplicate of '{existing_doc.filename}'. "
                f"Please delete the duplicate files and consolidate them into a "
                f"single unique file before uploading."
            )
            logger.warning(
                f"Duplicate document detected: {document_id} is duplicate of "
                f"{existing_doc.id} ({existing_doc.filename})"
            )
            await db_client.update_document_status(
                document_id,
                "failed",
                error_message=error_message,
                docling_metadata={
                    "duplicate_of": existing_doc.document_uuid,
                    "duplicate_filename": existing_doc.filename,
                },
            )
            return

        # --- Conversion / chunking (delegated to MPS) -----------------------
        logger.info(f"Delegating document processing to MPS (mode={retrieval_mode})")
        mps_response = await mps_service_key_client.process_document(
            file_path=temp_file_path,
            filename=filename,
            content_type=mime_type or "application/octet-stream",
            retrieval_mode=retrieval_mode,
            max_tokens=max_tokens,
            organization_id=organization_id,
            created_by=created_by_provider_id,
        )
        docling_metadata = mps_response.get("docling_metadata", {})

        if retrieval_mode == "full_document":
            # Full-document mode stores the text as-is; no chunks, no embeddings.
            full_text = mps_response.get("full_text") or ""
            await db_client.update_document_full_text(document_id, full_text)
            await db_client.update_document_status(
                document_id,
                "completed",
                total_chunks=0,
                docling_metadata=docling_metadata,
            )
            logger.info(
                f"Successfully processed full_document {document_id}. "
                f"Text length: {len(full_text)} chars"
            )
            return

        # --- Chunked mode: fetch user embedding config, embed, persist ------
        embeddings_api_key = None
        embeddings_model = None
        embeddings_base_url = None
        if document.created_by:
            user_config = await db_client.get_user_configurations(document.created_by)
            if user_config.embeddings:
                embeddings_api_key = user_config.embeddings.api_key
                embeddings_model = user_config.embeddings.model
                embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
                logger.info(f"Using user embeddings config: model={embeddings_model}")

        if not embeddings_api_key:
            error_message = (
                "OpenAI API key not configured. Please set your API key in "
                "Model Configurations > Embedding to process documents."
            )
            logger.warning(f"Document {document_id}: {error_message}")
            await db_client.update_document_status(
                document_id, "failed", error_message=error_message
            )
            return

        embedding_service = OpenAIEmbeddingService(
            db_client=db_client,
            api_key=embeddings_api_key,
            model_id=embeddings_model or "text-embedding-3-small",
            base_url=embeddings_base_url,
        )

        mps_chunks = mps_response.get("chunks", [])
        if not mps_chunks:
            logger.warning(f"Document {document_id}: MPS returned zero chunks")

        chunk_records, chunk_texts = _build_chunk_records(
            mps_chunks, document_id, organization_id, embedding_service
        )

        logger.info(
            f"Generating embeddings for {len(chunk_texts)} chunks "
            f"using {embedding_service.get_model_id()}"
        )
        # Nothing to embed when MPS produced no chunks; skip the remote call.
        embeddings = (
            await embedding_service.embed_texts(chunk_texts) if chunk_texts else []
        )
        # zip() would silently truncate on a mismatch and leave chunks stored
        # without vectors, so fail loudly instead.
        if len(embeddings) != len(chunk_records):
            raise ValueError(
                f"Embedding count mismatch for document {document_id}: "
                f"expected {len(chunk_records)}, got {len(embeddings)}"
            )
        for chunk_record, embedding in zip(chunk_records, embeddings):
            chunk_record.embedding = embedding

        logger.info("Storing chunks in database")
        await db_client.create_chunks_batch(chunk_records)
        await db_client.update_document_status(
            document_id,
            "completed",
            total_chunks=len(chunk_records),
            docling_metadata=docling_metadata,
        )
        logger.info(
            f"Successfully processed knowledge base document {document_id}. "
            f"Total chunks: {len(chunk_records)}"
        )
    except Exception as e:
        # loguru is not stdlib logging: keyword arguments such as
        # ``exc_info=True`` are treated as str.format() fields, which logs no
        # traceback and can itself raise when the message contains braces.
        # logger.exception() attaches the active traceback and, with no extra
        # arguments, leaves the message unformatted.
        logger.exception(
            f"Error processing knowledge base document {document_id}: {e}"
        )
        try:
            await db_client.update_document_status(
                document_id, "failed", error_message=str(e)
            )
        except Exception as status_error:
            # Best effort only: never let the bookkeeping failure replace the
            # original error that is re-raised below.
            logger.error(
                f"Could not mark document {document_id} as failed: {status_error}"
            )
        raise
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
                logger.debug(f"Cleaned up temp file: {temp_file_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")