mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
fix: enable knowledge base with Dograh config v2
This commit is contained in:
parent
d675fd1fda
commit
efb25a0cc5
19 changed files with 557 additions and 113 deletions
|
|
@ -12,12 +12,28 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
|
||||
from api.services.gen_ai import build_embedding_service
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
MAX_FILE_SIZE_BYTES = 5 * 1024 * 1024
|
||||
EMBEDDING_BATCH_SIZE = 64
|
||||
|
||||
|
||||
async def _embed_texts_in_batches(
    embedding_service,
    texts: list[str],
    batch_size: int = EMBEDDING_BATCH_SIZE,
) -> list[list[float]]:
    """Generate embeddings in bounded batches for provider/MPS stability.

    Args:
        embedding_service: Object exposing an awaitable
            ``embed_texts(list[str])``; its result is appended to the output
            with ``list.extend``, so it must be iterable.
        texts: Texts to embed, in the order the embeddings should come back.
        batch_size: Maximum number of texts passed to one ``embed_texts`` call.

    Returns:
        The results of every batch concatenated in input order. An empty
        ``texts`` list yields an empty list without calling the service.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # A zero step makes range() raise an opaque error, and a negative step
    # yields no batches at all, which would silently return no embeddings.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        logger.info(
            f"Generating embedding batch {start // batch_size + 1} ({len(batch)} texts)"
        )
        # Batches are awaited one at a time so results stay in input order.
        embeddings.extend(await embedding_service.embed_texts(batch))
    return embeddings
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
|
|
@ -121,42 +137,13 @@ async def process_knowledge_base_document(
|
|||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
logger.info(f"Delegating document processing to MPS (mode={retrieval_mode})")
|
||||
mps_response = await mps_service_key_client.process_document(
|
||||
file_path=temp_file_path,
|
||||
filename=filename,
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
retrieval_mode=retrieval_mode,
|
||||
max_tokens=max_tokens,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by_provider_id,
|
||||
)
|
||||
|
||||
docling_metadata = mps_response.get("docling_metadata", {})
|
||||
|
||||
if retrieval_mode == "full_document":
|
||||
full_text = mps_response.get("full_text") or ""
|
||||
await db_client.update_document_full_text(document_id, full_text)
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=0,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully processed full_document {document_id}. "
|
||||
f"Text length: {len(full_text)} chars"
|
||||
)
|
||||
return
|
||||
|
||||
# Chunked mode: fetch user embedding config, embed, and persist chunks.
|
||||
embeddings_provider = None
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if document.created_by:
|
||||
if retrieval_mode == "chunked" and document.created_by:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
|
|
@ -188,6 +175,34 @@ async def process_knowledge_base_document(
|
|||
f"model={embeddings_model}"
|
||||
)
|
||||
|
||||
logger.info(f"Delegating document processing to MPS (mode={retrieval_mode})")
|
||||
mps_response = await mps_service_key_client.process_document(
|
||||
file_path=temp_file_path,
|
||||
filename=filename,
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
retrieval_mode=retrieval_mode,
|
||||
max_tokens=max_tokens,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by_provider_id,
|
||||
)
|
||||
|
||||
docling_metadata = mps_response.get("docling_metadata", {})
|
||||
|
||||
if retrieval_mode == "full_document":
|
||||
full_text = mps_response.get("full_text") or ""
|
||||
await db_client.update_document_full_text(document_id, full_text)
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=0,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully processed full_document {document_id}. "
|
||||
f"Text length: {len(full_text)} chars"
|
||||
)
|
||||
return
|
||||
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"API key not configured. Please set your API key in "
|
||||
|
|
@ -199,21 +214,20 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
return
|
||||
|
||||
if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint:
|
||||
embedding_service = AzureOpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
endpoint=embeddings_endpoint,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
# Ingestion runs outside any workflow run, so resolve the MPS correlation
|
||||
# id here (mint only for orgs already on v2; never create an account).
|
||||
embedding_service = await build_embedding_service(
|
||||
db_client=db_client,
|
||||
provider=embeddings_provider,
|
||||
api_key=embeddings_api_key,
|
||||
model=embeddings_model,
|
||||
base_url=embeddings_base_url,
|
||||
endpoint=embeddings_endpoint,
|
||||
api_version=embeddings_api_version,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by_provider_id,
|
||||
resolve_correlation=True,
|
||||
)
|
||||
|
||||
mps_chunks = mps_response.get("chunks", [])
|
||||
if not mps_chunks:
|
||||
|
|
@ -242,12 +256,21 @@ async def process_knowledge_base_document(
|
|||
f"Generating embeddings for {len(chunk_texts)} chunks "
|
||||
f"using {embedding_service.get_model_id()}"
|
||||
)
|
||||
embeddings = await embedding_service.embed_texts(chunk_texts)
|
||||
embeddings = await _embed_texts_in_batches(embedding_service, chunk_texts)
|
||||
if len(embeddings) != len(chunk_records):
|
||||
raise ValueError(
|
||||
"Embedding count mismatch: "
|
||||
f"expected {len(chunk_records)}, got {len(embeddings)}"
|
||||
)
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
await db_client.replace_chunks_for_document(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunks=chunk_records,
|
||||
)
|
||||
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
|
|
@ -262,9 +285,8 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing knowledge base document {document_id}: {e}",
|
||||
exc_info=True,
|
||||
logger.exception(
|
||||
"Error processing knowledge base document {}: {}", document_id, e
|
||||
)
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=str(e)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue