mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
fix: enable knowledge base with Dograh config v2
This commit is contained in:
parent
d675fd1fda
commit
efb25a0cc5
19 changed files with 557 additions and 113 deletions
|
|
@ -5,7 +5,7 @@ from pathlib import Path
|
|||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
|
|
@ -300,6 +300,31 @@ class KnowledgeBaseClient(BaseDBClient):
|
|||
logger.info(f"Created {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
async def replace_chunks_for_document(
    self,
    document_id: int,
    organization_id: int,
    chunks: List[KnowledgeBaseChunkModel],
) -> List[KnowledgeBaseChunkModel]:
    """Swap a document's stored chunks for a precomputed batch.

    Rows matching both ``document_id`` and ``organization_id`` are deleted
    and ``chunks`` are added in the same session, followed by one commit.

    Args:
        document_id: Document whose chunks are being replaced.
        organization_id: Organization the existing chunks must match.
        chunks: New chunk rows to persist.

    Returns:
        The same ``chunks`` list, each item refreshed after the commit.
    """
    purge_existing = delete(KnowledgeBaseChunkModel).where(
        KnowledgeBaseChunkModel.document_id == document_id,
        KnowledgeBaseChunkModel.organization_id == organization_id,
    )

    async with self.async_session() as session:
        await session.execute(purge_existing)
        session.add_all(chunks)
        await session.commit()

        # Reload each row's attributes from the database after the commit.
        for stored in chunks:
            await session.refresh(stored)

        logger.info(
            f"Replaced chunks for document {document_id}: {len(chunks)} chunks"
        )
        return chunks
|
||||
|
||||
async def get_chunks_for_document(
|
||||
self,
|
||||
document_id: int,
|
||||
|
|
|
|||
|
|
@ -373,11 +373,7 @@ async def search_chunks(
|
|||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import (
|
||||
AzureOpenAIEmbeddingService,
|
||||
OpenAIEmbeddingService,
|
||||
)
|
||||
from api.services.gen_ai import build_embedding_service
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
|
|
@ -405,22 +401,20 @@ async def search_chunks(
|
|||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
|
||||
# Initialize embedding service based on provider
|
||||
if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint:
|
||||
embedding_service = AzureOpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
endpoint=embeddings_endpoint,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
# Manual search runs outside any workflow run, so resolve the MPS
|
||||
# correlation id here (mint only for orgs already on v2; never create one).
|
||||
embedding_service = await build_embedding_service(
|
||||
db_client=db_client,
|
||||
provider=embeddings_provider,
|
||||
api_key=embeddings_api_key,
|
||||
model=embeddings_model,
|
||||
base_url=embeddings_base_url,
|
||||
endpoint=embeddings_endpoint,
|
||||
api_version=embeddings_api_version,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
resolve_correlation=True,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ def _compile_dograh_configuration(
|
|||
embeddings=DograhEmbeddingsConfiguration(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
model="dograh_embedding_v1",
|
||||
),
|
||||
is_realtime=False,
|
||||
managed_service_version=2,
|
||||
|
|
|
|||
|
|
@ -457,6 +457,11 @@ async def create_user_configuration_with_mps_key(
|
|||
"api_key": [service_key],
|
||||
"model": "default",
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": ServiceProviders.DOGRAH.value,
|
||||
"api_key": [service_key],
|
||||
"model": "dograh_embedding_v1",
|
||||
},
|
||||
}
|
||||
effective_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return effective_config
|
||||
|
|
|
|||
|
|
@ -316,6 +316,7 @@ def convert_legacy_ai_model_configuration_to_v2(
|
|||
|
||||
|
||||
def dograh_embeddings_base_url() -> str:
    """Return the MPS base URL used for managed embeddings.

    AsyncOpenAI appends "/embeddings" to the base URL; MPS exposes that
    route under /api/v1/llm.
    """
    base_url = f"{MPS_API_URL}/api/v1/llm"
    return base_url
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1726,7 +1726,7 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
DOGRAH_EMBEDDING_MODELS = ["default"]
|
||||
DOGRAH_EMBEDDING_MODELS = ["dograh_embedding_v1"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
|
|
@ -1734,7 +1734,7 @@ class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
default="dograh_embedding_v1",
|
||||
description="Dograh-managed embedding model.",
|
||||
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,11 @@ from .embedding import (
|
|||
AzureEmbeddingAPIKeyNotConfiguredError,
|
||||
AzureOpenAIEmbeddingService,
|
||||
BaseEmbeddingService,
|
||||
DograhEmbeddingService,
|
||||
EmbeddingAPIKeyNotConfiguredError,
|
||||
OpenAIEmbeddingService,
|
||||
build_embedding_service,
|
||||
resolve_embedding_correlation_id,
|
||||
)
|
||||
from .json_parser import parse_llm_json
|
||||
|
||||
|
|
@ -13,7 +16,10 @@ __all__ = [
|
|||
"AzureEmbeddingAPIKeyNotConfiguredError",
|
||||
"AzureOpenAIEmbeddingService",
|
||||
"BaseEmbeddingService",
|
||||
"DograhEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"OpenAIEmbeddingService",
|
||||
"build_embedding_service",
|
||||
"resolve_embedding_correlation_id",
|
||||
"parse_llm_json",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,12 +5,17 @@ from .azure_openai_service import (
|
|||
AzureOpenAIEmbeddingService,
|
||||
)
|
||||
from .base import BaseEmbeddingService
|
||||
from .dograh_service import DograhEmbeddingService
|
||||
from .factory import build_embedding_service, resolve_embedding_correlation_id
|
||||
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
|
||||
|
||||
__all__ = [
|
||||
"AzureEmbeddingAPIKeyNotConfiguredError",
|
||||
"AzureOpenAIEmbeddingService",
|
||||
"BaseEmbeddingService",
|
||||
"DograhEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"OpenAIEmbeddingService",
|
||||
"build_embedding_service",
|
||||
"resolve_embedding_correlation_id",
|
||||
]
|
||||
|
|
|
|||
69
api/services/gen_ai/embedding/dograh_service.py
Normal file
69
api/services/gen_ai/embedding/dograh_service.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Dograh-managed embedding service.
|
||||
|
||||
Routes embeddings through Dograh's managed proxy (MPS). This mirrors the managed
|
||||
voice services (``DograhLLMService`` / ``DograhTTSService``): when a server-minted
|
||||
MPS correlation id is present, it forwards the MPS billing v2 protocol
|
||||
(``correlation_id`` + ``mps_billing_version``) in the request body so MPS can
|
||||
authorize and attribute the call. With no correlation id (e.g. a v1 org) it
|
||||
behaves like a plain OpenAI-compatible call, which MPS accepts.
|
||||
|
||||
Keeping this in a subclass keeps ``OpenAIEmbeddingService`` a generic
|
||||
OpenAI-compatible client; only the managed path carries MPS-specific metadata,
|
||||
so BYOK OpenAI/Azure requests never ship MPS fields to the real provider.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
|
||||
from .openai_service import DEFAULT_MODEL_ID, OpenAIEmbeddingService
|
||||
|
||||
# Protocol contract with MPS (see model_services
|
||||
# api/services/model_service_correlations.py). Kept local to avoid coupling the
|
||||
# app layer to the pipecat package, which defines its own copy for voice.
|
||||
MPS_BILLING_VERSION_KEY = "mps_billing_version"
|
||||
MPS_BILLING_VERSION_V2 = "2"
|
||||
|
||||
|
||||
class DograhEmbeddingService(OpenAIEmbeddingService):
    """Embedding client aimed at Dograh's managed proxy (MPS).

    Behaves like ``OpenAIEmbeddingService`` except that, when a correlation
    id was supplied, each request carries MPS billing v2 metadata.
    """

    def __init__(
        self,
        db_client: DBClient,
        api_key: Optional[str] = None,
        model_id: str = DEFAULT_MODEL_ID,
        base_url: Optional[str] = None,
        correlation_id: Optional[str] = None,
    ):
        """Set up the managed embedding service.

        Args:
            db_client: Database client for vector similarity search.
            api_key: Dograh-managed MPS service key.
            model_id: Embedding model/tier id (defaults to ``DEFAULT_MODEL_ID``).
            base_url: MPS embeddings base URL.
            correlation_id: Server-minted MPS correlation id. When truthy, the
                MPS billing v2 metadata is attached to each request; otherwise
                requests go out without it (valid for v1 orgs).
        """
        super().__init__(
            db_client=db_client,
            api_key=api_key,
            model_id=model_id,
            base_url=base_url,
        )
        self._correlation_id = correlation_id

    def _request_kwargs(self) -> Dict[str, Any]:
        """Return the extra request kwargs carrying MPS billing v2 metadata.

        Yields an empty dict when no correlation id is set, so the request is
        a plain OpenAI-compatible call.
        """
        if self._correlation_id:
            billing_metadata = {
                "correlation_id": self._correlation_id,
                MPS_BILLING_VERSION_KEY: MPS_BILLING_VERSION_V2,
            }
            return {"extra_body": {"metadata": billing_metadata}}
        return {}
|
||||
137
api/services/gen_ai/embedding/factory.py
Normal file
137
api/services/gen_ai/embedding/factory.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""Factory for embedding services, including the Dograh-managed (MPS) path.
|
||||
|
||||
Centralizes the provider branching (Azure BYOK / Dograh-managed / OpenAI-compatible
|
||||
BYOK) that was previously duplicated across document ingestion, the search route,
|
||||
and the RAG tool, and resolves the MPS billing v2 protocol the same way the voice
|
||||
path does: attach it only for orgs already on v2, and never create a billing
|
||||
account to do so.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
|
||||
from .azure_openai_service import AzureOpenAIEmbeddingService
|
||||
from .base import BaseEmbeddingService
|
||||
from .dograh_service import DograhEmbeddingService
|
||||
from .openai_service import OpenAIEmbeddingService
|
||||
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
|
||||
|
||||
|
||||
async def resolve_embedding_correlation_id(
    *,
    organization_id: Optional[int],
    service_key: Optional[str],
    created_by: Optional[str] = None,
) -> Optional[str]:
    """Find an MPS correlation id for a managed embedding call outside a run.

    Gating, mirroring the voice path:

    - OSS deployments use a pasted hosted v2 key, so a correlation id is
      minted straight away through the bearer endpoint.
    - Hosted/SaaS deployments first read the org's billing mode and mint only
      when it is already ``"v2"``.

    Args:
        organization_id: Organization to look up on hosted deployments.
        service_key: MPS service key used to mint the correlation id.
        created_by: Forwarded to the billing status lookup.

    Returns:
        The minted correlation id, or ``None`` when the call should go out
        without the v2 protocol (no service key, no organization, org not on
        v2, or any error while talking to MPS).
    """
    if not service_key:
        return None

    # Lazy imports avoid import-time cycles between the gen_ai and service
    # layers (same inline-import convention used elsewhere in the app).
    from api.constants import DEPLOYMENT_MODE
    from api.services.mps_service_key_client import mps_service_key_client

    try:
        if DEPLOYMENT_MODE != "oss":
            # Hosted path: only orgs already on v2 get a correlation id.
            if organization_id is None:
                return None

            billing = await mps_service_key_client.get_billing_account_status(
                organization_id, created_by=created_by
            )
            already_v2 = bool(billing) and billing.get("billing_mode") == "v2"
            if not already_v2:
                return None

        response = await mps_service_key_client.create_correlation_id(
            service_key=service_key
        )
        return response.get("correlation_id")
    except Exception as e:
        # Best effort: an MPS failure must not break the embedding call.
        logger.warning(
            "Could not resolve MPS correlation id for managed embeddings; "
            "sending without v2 protocol: {}",
            e,
        )
        return None
|
||||
|
||||
|
||||
async def build_embedding_service(
    *,
    db_client: DBClient,
    provider: Optional[str],
    api_key: Optional[str],
    model: Optional[str],
    base_url: Optional[str] = None,
    endpoint: Optional[str] = None,
    api_version: Optional[str] = None,
    correlation_id: Optional[str] = None,
    organization_id: Optional[int] = None,
    created_by: Optional[str] = None,
    resolve_correlation: bool = False,
) -> BaseEmbeddingService:
    """Pick and construct the embedding service for a provider/config.

    Args:
        db_client: Database client handed to the service.
        provider: Provider id from the embeddings configuration.
        api_key: API key (or MPS service key for the Dograh provider).
        model: Model id; falls back to ``DEFAULT_EMBEDDING_MODEL`` when empty.
        base_url: Base URL for the OpenAI-compatible and Dograh services.
        endpoint: Azure endpoint; Azure is only selected when this is set.
        api_version: Azure API version; falls back to
            ``DEFAULT_AZURE_API_VERSION`` when empty.
        correlation_id: A correlation id already available in context (e.g.
            the running workflow's MPS correlation id). Dograh provider only.
        organization_id: Passed to the correlation resolver.
        created_by: Passed to the correlation resolver.
        resolve_correlation: When True and ``correlation_id`` is None, resolve
            one for the Dograh provider via ``resolve_embedding_correlation_id``
            (for calls outside a workflow run: ingestion, manual search).

    Returns:
        An Azure, Dograh-managed, or OpenAI-compatible embedding service.
    """
    from api.services.configuration.registry import ServiceProviders

    # Arguments every service constructor takes.
    shared = {
        "db_client": db_client,
        "api_key": api_key,
        "model_id": model or DEFAULT_EMBEDDING_MODEL,
    }

    # Azure without an endpoint falls through to the branches below.
    if endpoint and provider == ServiceProviders.AZURE.value:
        return AzureOpenAIEmbeddingService(
            endpoint=endpoint,
            api_version=api_version or DEFAULT_AZURE_API_VERSION,
            **shared,
        )

    if provider != ServiceProviders.DOGRAH.value:
        return OpenAIEmbeddingService(base_url=base_url, **shared)

    managed_correlation_id = correlation_id
    if managed_correlation_id is None and resolve_correlation:
        managed_correlation_id = await resolve_embedding_correlation_id(
            organization_id=organization_id,
            service_key=api_key,
            created_by=created_by,
        )
    return DograhEmbeddingService(
        base_url=base_url,
        correlation_id=managed_correlation_id,
        **shared,
    )
|
||||
|
|
@ -85,6 +85,14 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
if not self._api_key_configured or self.client is None:
|
||||
raise EmbeddingAPIKeyNotConfiguredError()
|
||||
|
||||
def _request_kwargs(self) -> Dict[str, Any]:
|
||||
"""Extra kwargs merged into every embeddings.create() call.
|
||||
|
||||
Override hook for subclasses (e.g. DograhEmbeddingService injects the MPS
|
||||
billing protocol here). The base service adds nothing.
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts using OpenAI API.
|
||||
|
||||
|
|
@ -97,6 +105,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
response = await self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model_id,
|
||||
**self._request_kwargs(),
|
||||
)
|
||||
return [item.embedding for item in response.data]
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -41,9 +41,7 @@ async def _persist_amd_result_if_present(
|
|||
gathered_context={"answered_by": amd_result.answered_by},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Failed to persist AMD result: {exc}"
|
||||
)
|
||||
logger.warning(f"[run {workflow_run_id}] Failed to persist AMD result: {exc}")
|
||||
|
||||
|
||||
@router.post("/twiml", include_in_schema=False)
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ from loguru import logger
|
|||
from opentelemetry import trace
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
|
||||
from api.services.gen_ai import build_embedding_service
|
||||
from api.services.pipecat.tracing_config import ensure_tracing
|
||||
|
||||
|
||||
|
|
@ -266,33 +265,19 @@ async def _perform_retrieval(
|
|||
"Model Configurations > Embedding."
|
||||
)
|
||||
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.AZURE.value
|
||||
and embeddings_endpoint
|
||||
):
|
||||
embedding_service = AzureOpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
endpoint=embeddings_endpoint,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
default_headers = None
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.DOGRAH.value
|
||||
and correlation_id
|
||||
):
|
||||
default_headers = {
|
||||
"X-Dograh-Correlation-Id": correlation_id,
|
||||
}
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
# Search runs inside a workflow run: reuse the run's MPS correlation
|
||||
# id (present only for v2 orgs; None otherwise → sent without the
|
||||
# protocol). The Dograh-managed path forwards it via request metadata.
|
||||
embedding_service = await build_embedding_service(
|
||||
db_client=db_client,
|
||||
provider=embeddings_provider,
|
||||
api_key=embeddings_api_key,
|
||||
model=embeddings_model,
|
||||
base_url=embeddings_base_url,
|
||||
endpoint=embeddings_endpoint,
|
||||
api_version=embeddings_api_version,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=query,
|
||||
|
|
|
|||
|
|
@ -12,12 +12,28 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
|
||||
from api.services.gen_ai import build_embedding_service
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
MAX_FILE_SIZE_BYTES = 5 * 1024 * 1024
|
||||
EMBEDDING_BATCH_SIZE = 64
|
||||
|
||||
|
||||
async def _embed_texts_in_batches(
    embedding_service,
    texts: list[str],
    batch_size: int = EMBEDDING_BATCH_SIZE,
) -> list[list[float]]:
    """Generate embeddings in bounded batches for provider/MPS stability.

    Args:
        embedding_service: Object exposing an async ``embed_texts(batch)``
            method that returns one vector per input text.
        texts: Texts to embed, in order.
        batch_size: Maximum number of texts sent per ``embed_texts`` call.

    Returns:
        The vectors from every batch, concatenated in input order. An empty
        ``texts`` list yields an empty result without calling the service.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # A negative step would make range() empty and silently return no
    # embeddings; zero would raise an opaque range() error. Fail clearly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    embeddings: list[list[float]] = []
    batch_starts = range(0, len(texts), batch_size)
    for batch_number, start in enumerate(batch_starts, start=1):
        batch = texts[start : start + batch_size]
        logger.info(
            f"Generating embedding batch {batch_number} ({len(batch)} texts)"
        )
        embeddings.extend(await embedding_service.embed_texts(batch))
    return embeddings
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
|
|
@ -121,42 +137,13 @@ async def process_knowledge_base_document(
|
|||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
logger.info(f"Delegating document processing to MPS (mode={retrieval_mode})")
|
||||
mps_response = await mps_service_key_client.process_document(
|
||||
file_path=temp_file_path,
|
||||
filename=filename,
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
retrieval_mode=retrieval_mode,
|
||||
max_tokens=max_tokens,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by_provider_id,
|
||||
)
|
||||
|
||||
docling_metadata = mps_response.get("docling_metadata", {})
|
||||
|
||||
if retrieval_mode == "full_document":
|
||||
full_text = mps_response.get("full_text") or ""
|
||||
await db_client.update_document_full_text(document_id, full_text)
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=0,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully processed full_document {document_id}. "
|
||||
f"Text length: {len(full_text)} chars"
|
||||
)
|
||||
return
|
||||
|
||||
# Chunked mode: fetch user embedding config, embed, and persist chunks.
|
||||
embeddings_provider = None
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if document.created_by:
|
||||
if retrieval_mode == "chunked" and document.created_by:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
|
|
@ -188,6 +175,34 @@ async def process_knowledge_base_document(
|
|||
f"model={embeddings_model}"
|
||||
)
|
||||
|
||||
logger.info(f"Delegating document processing to MPS (mode={retrieval_mode})")
|
||||
mps_response = await mps_service_key_client.process_document(
|
||||
file_path=temp_file_path,
|
||||
filename=filename,
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
retrieval_mode=retrieval_mode,
|
||||
max_tokens=max_tokens,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by_provider_id,
|
||||
)
|
||||
|
||||
docling_metadata = mps_response.get("docling_metadata", {})
|
||||
|
||||
if retrieval_mode == "full_document":
|
||||
full_text = mps_response.get("full_text") or ""
|
||||
await db_client.update_document_full_text(document_id, full_text)
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=0,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully processed full_document {document_id}. "
|
||||
f"Text length: {len(full_text)} chars"
|
||||
)
|
||||
return
|
||||
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"API key not configured. Please set your API key in "
|
||||
|
|
@ -199,21 +214,20 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
return
|
||||
|
||||
if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint:
|
||||
embedding_service = AzureOpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
endpoint=embeddings_endpoint,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
# Ingestion runs outside any workflow run, so resolve the MPS correlation
|
||||
# id here (mint only for orgs already on v2; never create an account).
|
||||
embedding_service = await build_embedding_service(
|
||||
db_client=db_client,
|
||||
provider=embeddings_provider,
|
||||
api_key=embeddings_api_key,
|
||||
model=embeddings_model,
|
||||
base_url=embeddings_base_url,
|
||||
endpoint=embeddings_endpoint,
|
||||
api_version=embeddings_api_version,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by_provider_id,
|
||||
resolve_correlation=True,
|
||||
)
|
||||
|
||||
mps_chunks = mps_response.get("chunks", [])
|
||||
if not mps_chunks:
|
||||
|
|
@ -242,12 +256,21 @@ async def process_knowledge_base_document(
|
|||
f"Generating embeddings for {len(chunk_texts)} chunks "
|
||||
f"using {embedding_service.get_model_id()}"
|
||||
)
|
||||
embeddings = await embedding_service.embed_texts(chunk_texts)
|
||||
embeddings = await _embed_texts_in_batches(embedding_service, chunk_texts)
|
||||
if len(embeddings) != len(chunk_records):
|
||||
raise ValueError(
|
||||
"Embedding count mismatch: "
|
||||
f"expected {len(chunk_records)}, got {len(embeddings)}"
|
||||
)
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
await db_client.replace_chunks_for_document(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunks=chunk_records,
|
||||
)
|
||||
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
|
|
@ -262,9 +285,8 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing knowledge base document {document_id}: {e}",
|
||||
exc_info=True,
|
||||
logger.exception(
|
||||
"Error processing knowledge base document {}: {}", document_id, e
|
||||
)
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=str(e)
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
|||
assert effective.stt.provider == "dograh"
|
||||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
assert effective.embeddings.model == "dograh_embedding_v1"
|
||||
assert effective.managed_service_version == 2
|
||||
|
||||
|
||||
|
|
|
|||
162
api/tests/test_dograh_embedding_service.py
Normal file
162
api/tests/test_dograh_embedding_service.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Tests for the Dograh-managed embedding service and its correlation resolver."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.gen_ai.embedding.dograh_service import DograhEmbeddingService
|
||||
from api.services.gen_ai.embedding.factory import resolve_embedding_correlation_id
|
||||
|
||||
|
||||
def _service_with_fake_client(correlation_id):
    """Build a DograhEmbeddingService whose client is replaced by a stub.

    Returns the service together with the AsyncMock standing in for
    ``client.embeddings.create``.
    """
    canned_response = SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2])])
    create = AsyncMock(return_value=canned_response)

    service = DograhEmbeddingService(
        db_client=None,
        api_key="sk-test",
        model_id="text-embedding-3-small",
        base_url=None,
        correlation_id=correlation_id,
    )
    stub_embeddings = SimpleNamespace(create=create)
    service.client = SimpleNamespace(embeddings=stub_embeddings)
    return service, create
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dograh_embedding_forwards_v2_protocol_when_correlation_present():
    """A correlation id is forwarded as MPS billing v2 metadata in extra_body."""
    service, create = _service_with_fake_client("corr-123")
    expected_extra_body = {
        "metadata": {
            "correlation_id": "corr-123",
            "mps_billing_version": "2",
        }
    }

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    sent = create.await_args.kwargs
    assert sent["input"] == ["hello"]
    assert sent["model"] == "text-embedding-3-small"
    assert sent["extra_body"] == expected_extra_body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dograh_embedding_sends_plain_without_correlation():
    """Without a correlation id the request carries no MPS metadata."""
    service, create = _service_with_fake_client(None)

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    # No correlation id (e.g. a v1 org) → no MPS metadata; MPS accepts plain calls.
    sent = create.await_args.kwargs
    assert "extra_body" not in sent
|
||||
|
||||
|
||||
def _fake_mps_client(*, status_return=None, minted="minted"):
|
||||
return SimpleNamespace(
|
||||
get_billing_account_status=AsyncMock(return_value=status_return),
|
||||
create_correlation_id=AsyncMock(return_value={"correlation_id": minted}),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_oss_mints_directly(monkeypatch):
    """OSS deployments mint a correlation id without reading billing status."""
    mps = _fake_mps_client()
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "oss")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )

    correlation_id = await resolve_embedding_correlation_id(
        service_key="sk-mps", organization_id=None
    )

    assert correlation_id == "minted"
    mps.get_billing_account_status.assert_not_awaited()
    mps.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v2_mints(monkeypatch):
    """Hosted orgs already on v2 get a minted correlation id."""
    mps = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )

    correlation_id = await resolve_embedding_correlation_id(
        service_key="sk-mps", organization_id=42, created_by="user-1"
    )

    assert correlation_id == "minted"
    mps.get_billing_account_status.assert_awaited_once_with(42, created_by="user-1")
    mps.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v1_returns_none_without_minting(monkeypatch):
    """Hosted orgs on v1 billing get no correlation id and nothing is minted."""
    mps = _fake_mps_client(status_return={"billing_mode": "v1"})
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )

    correlation_id = await resolve_embedding_correlation_id(
        service_key="sk-mps", organization_id=42
    )

    assert correlation_id is None
    mps.create_correlation_id.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_no_account_returns_none(monkeypatch):
    """A hosted org with no billing status gets no correlation id."""
    mps = _fake_mps_client(status_return=None)
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )

    correlation_id = await resolve_embedding_correlation_id(
        service_key="sk-mps", organization_id=42
    )

    assert correlation_id is None
    mps.create_correlation_id.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_no_service_key_returns_none(monkeypatch):
    """Without a service key the resolver returns None and never calls MPS."""
    mps = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )

    correlation_id = await resolve_embedding_correlation_id(
        service_key=None, organization_id=42
    )

    assert correlation_id is None
    mps.create_correlation_id.assert_not_awaited()
    mps.get_billing_account_status.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_swallows_errors(monkeypatch):
    """An error from MPS makes the resolver return None instead of raising."""
    failing_status = AsyncMock(side_effect=RuntimeError("mps down"))
    mps = SimpleNamespace(
        get_billing_account_status=failing_status,
        create_correlation_id=AsyncMock(),
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )

    # A transient MPS failure must not break embeddings — fall back to no protocol.
    correlation_id = await resolve_embedding_correlation_id(
        service_key="sk-mps", organization_id=42
    )

    assert correlation_id is None
|
||||
26
api/tests/test_knowledge_base_processing_embeddings.py
Normal file
26
api/tests/test_knowledge_base_processing_embeddings.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import pytest
|
||||
|
||||
from api.tasks.knowledge_base_processing import _embed_texts_in_batches
|
||||
|
||||
|
||||
class FakeEmbeddingService:
    """Stub embedding service that records every batch it is asked to embed."""

    def __init__(self):
        # One entry per embed_texts call, holding a copy of that batch.
        self.calls = []

    async def embed_texts(self, texts):
        """Record the batch and return one vector per text: its length as a float."""
        recorded = list(texts)
        self.calls.append(recorded)
        return [[float(len(item))] for item in texts]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_embed_texts_in_batches_preserves_order():
    """Batches are sent in order and their vectors are concatenated in order."""
    service = FakeEmbeddingService()
    texts = ["a", "bb", "ccc", "dddd", "eeeee"]

    embeddings = await _embed_texts_in_batches(service, texts, batch_size=2)

    assert embeddings == [[1.0], [2.0], [3.0], [4.0], [5.0]]
    assert service.calls == [["a", "bb"], ["ccc", "dddd"], ["eeeee"]]
|
||||
Loading…
Add table
Add a link
Reference in a new issue