fix: enable knowledge base with Dograh config v2

This commit is contained in:
Abhishek Kumar 2026-06-25 22:21:11 +05:30
parent d675fd1fda
commit efb25a0cc5
19 changed files with 557 additions and 113 deletions

View file

@ -457,6 +457,11 @@ async def create_user_configuration_with_mps_key(
"api_key": [service_key],
"model": "default",
},
"embeddings": {
"provider": ServiceProviders.DOGRAH.value,
"api_key": [service_key],
"model": "dograh_embedding_v1",
},
}
effective_config = EffectiveAIModelConfiguration(**configuration)
return effective_config

View file

@ -316,6 +316,7 @@ def convert_legacy_ai_model_configuration_to_v2(
def dograh_embeddings_base_url() -> str:
# AsyncOpenAI appends "/embeddings"; MPS exposes that under /api/v1/llm.
return f"{MPS_API_URL}/api/v1/llm"

View file

@ -1726,7 +1726,7 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
)
DOGRAH_EMBEDDING_MODELS = ["default"]
DOGRAH_EMBEDDING_MODELS = ["dograh_embedding_v1"]
@register_embeddings
@ -1734,7 +1734,7 @@ class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: str = Field(
default="default",
default="dograh_embedding_v1",
description="Dograh-managed embedding model.",
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
)

View file

@ -4,8 +4,11 @@ from .embedding import (
AzureEmbeddingAPIKeyNotConfiguredError,
AzureOpenAIEmbeddingService,
BaseEmbeddingService,
DograhEmbeddingService,
EmbeddingAPIKeyNotConfiguredError,
OpenAIEmbeddingService,
build_embedding_service,
resolve_embedding_correlation_id,
)
from .json_parser import parse_llm_json
@ -13,7 +16,10 @@ __all__ = [
"AzureEmbeddingAPIKeyNotConfiguredError",
"AzureOpenAIEmbeddingService",
"BaseEmbeddingService",
"DograhEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"OpenAIEmbeddingService",
"build_embedding_service",
"resolve_embedding_correlation_id",
"parse_llm_json",
]

View file

@ -5,12 +5,17 @@ from .azure_openai_service import (
AzureOpenAIEmbeddingService,
)
from .base import BaseEmbeddingService
from .dograh_service import DograhEmbeddingService
from .factory import build_embedding_service, resolve_embedding_correlation_id
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
__all__ = [
"AzureEmbeddingAPIKeyNotConfiguredError",
"AzureOpenAIEmbeddingService",
"BaseEmbeddingService",
"DograhEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"OpenAIEmbeddingService",
"build_embedding_service",
"resolve_embedding_correlation_id",
]

View file

@ -0,0 +1,69 @@
"""Dograh-managed embedding service.
Routes embeddings through Dograh's managed proxy (MPS). This mirrors the managed
voice services (``DograhLLMService`` / ``DograhTTSService``): when a server-minted
MPS correlation id is present, it forwards the MPS billing v2 protocol
(``correlation_id`` + ``mps_billing_version``) in the request body so MPS can
authorize and attribute the call. With no correlation id (e.g. a v1 org) it
behaves like a plain OpenAI-compatible call, which MPS accepts.
Keeping this in a subclass keeps ``OpenAIEmbeddingService`` a generic
OpenAI-compatible client; only the managed path carries MPS-specific metadata,
so BYOK OpenAI/Azure requests never ship MPS fields to the real provider.
"""
from typing import Any, Dict, Optional
from api.db.db_client import DBClient
from .openai_service import DEFAULT_MODEL_ID, OpenAIEmbeddingService
# Protocol contract with MPS (see model_services
# api/services/model_service_correlations.py). Kept local to avoid coupling the
# app layer to the pipecat package, which defines its own copy for voice.
MPS_BILLING_VERSION_KEY = "mps_billing_version"
MPS_BILLING_VERSION_V2 = "2"
class DograhEmbeddingService(OpenAIEmbeddingService):
"""OpenAI-compatible embedding client pointed at Dograh's managed proxy."""
def __init__(
self,
db_client: DBClient,
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
base_url: Optional[str] = None,
correlation_id: Optional[str] = None,
):
"""Initialize the managed embedding service.
Args:
db_client: Database client for vector similarity search.
api_key: Dograh-managed MPS service key.
model_id: Embedding model/tier id (default: text-embedding-3-small).
base_url: MPS embeddings base URL.
correlation_id: Server-minted MPS correlation id. When set, the MPS
billing v2 protocol is forwarded with each request. When None,
requests are sent without the protocol (valid for v1 orgs).
"""
super().__init__(
db_client=db_client,
api_key=api_key,
model_id=model_id,
base_url=base_url,
)
self._correlation_id = correlation_id
def _request_kwargs(self) -> Dict[str, Any]:
"""Forward the MPS billing v2 protocol when a correlation id is present."""
if not self._correlation_id:
return {}
return {
"extra_body": {
"metadata": {
"correlation_id": self._correlation_id,
MPS_BILLING_VERSION_KEY: MPS_BILLING_VERSION_V2,
}
}
}

View file

@ -0,0 +1,137 @@
"""Factory for embedding services, including the Dograh-managed (MPS) path.
Centralizes the provider branching (Azure BYOK / Dograh-managed / OpenAI-compatible
BYOK) that was previously duplicated across document ingestion, the search route,
and the RAG tool, and resolves the MPS billing v2 protocol the same way the voice
path does: attach it only for orgs already on v2, and never create a billing
account to do so.
"""
from typing import Optional
from loguru import logger
from api.db.db_client import DBClient
from .azure_openai_service import AzureOpenAIEmbeddingService
from .base import BaseEmbeddingService
from .dograh_service import DograhEmbeddingService
from .openai_service import OpenAIEmbeddingService
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
async def resolve_embedding_correlation_id(
*,
organization_id: Optional[int],
service_key: Optional[str],
created_by: Optional[str] = None,
) -> Optional[str]:
"""Resolve an MPS correlation id for a managed embedding call made outside a run.
Mirrors the voice path's gating:
- OSS deployments use a pasted hosted v2 key (v2 by definition), so mint
directly via the bearer endpoint matching ``_authorize_oss_managed_v2_correlation``.
- Hosted/SaaS: read the org's billing mode (no side effects) and mint only when
it is already v2. Minting for an already-v2 org is a no-op on the account.
Returns ``None`` when the call should be sent without the protocol; MPS accepts
un-gated embedding calls from v1 orgs. Never creates a v2 billing account.
"""
if not service_key:
return None
# Imported lazily to avoid import-time cycles between the gen_ai and service
# layers (matches the inline-import convention used elsewhere in the app).
from api.constants import DEPLOYMENT_MODE
from api.services.mps_service_key_client import mps_service_key_client
try:
if DEPLOYMENT_MODE == "oss":
minted = await mps_service_key_client.create_correlation_id(
service_key=service_key
)
return minted.get("correlation_id")
if organization_id is None:
return None
status = await mps_service_key_client.get_billing_account_status(
organization_id, created_by=created_by
)
if not status or status.get("billing_mode") != "v2":
return None
minted = await mps_service_key_client.create_correlation_id(
service_key=service_key
)
return minted.get("correlation_id")
except Exception as e:
logger.warning(
"Could not resolve MPS correlation id for managed embeddings; "
"sending without v2 protocol: {}",
e,
)
return None
async def build_embedding_service(
*,
db_client: DBClient,
provider: Optional[str],
api_key: Optional[str],
model: Optional[str],
base_url: Optional[str] = None,
endpoint: Optional[str] = None,
api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
resolve_correlation: bool = False,
) -> BaseEmbeddingService:
"""Construct the right embedding service for a provider/config.
Args:
correlation_id: A correlation id already available in context (e.g. the
running workflow's MPS correlation id). Used for the Dograh provider.
resolve_correlation: When True and no ``correlation_id`` is supplied, resolve
one for the Dograh provider via ``resolve_embedding_correlation_id``
(for calls made outside a workflow run: ingestion, manual search).
"""
from api.services.configuration.registry import ServiceProviders
model_id = model or DEFAULT_EMBEDDING_MODEL
if provider == ServiceProviders.AZURE.value and endpoint:
return AzureOpenAIEmbeddingService(
db_client=db_client,
api_key=api_key,
endpoint=endpoint,
model_id=model_id,
api_version=api_version or DEFAULT_AZURE_API_VERSION,
)
if provider == ServiceProviders.DOGRAH.value:
cid = correlation_id
if cid is None and resolve_correlation:
cid = await resolve_embedding_correlation_id(
organization_id=organization_id,
service_key=api_key,
created_by=created_by,
)
return DograhEmbeddingService(
db_client=db_client,
api_key=api_key,
model_id=model_id,
base_url=base_url,
correlation_id=cid,
)
return OpenAIEmbeddingService(
db_client=db_client,
api_key=api_key,
model_id=model_id,
base_url=base_url,
)

View file

@ -85,6 +85,14 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
if not self._api_key_configured or self.client is None:
raise EmbeddingAPIKeyNotConfiguredError()
def _request_kwargs(self) -> Dict[str, Any]:
"""Extra kwargs merged into every embeddings.create() call.
Override hook for subclasses (e.g. DograhEmbeddingService injects the MPS
billing protocol here). The base service adds nothing.
"""
return {}
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed a batch of texts using OpenAI API.
@ -97,6 +105,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
response = await self.client.embeddings.create(
input=texts,
model=self.model_id,
**self._request_kwargs(),
)
return [item.embedding for item in response.data]
except Exception as e:

View file

@ -41,9 +41,7 @@ async def _persist_amd_result_if_present(
gathered_context={"answered_by": amd_result.answered_by},
)
except Exception as exc:
logger.warning(
f"[run {workflow_run_id}] Failed to persist AMD result: {exc}"
)
logger.warning(f"[run {workflow_run_id}] Failed to persist AMD result: {exc}")
@router.post("/twiml", include_in_schema=False)

View file

@ -13,8 +13,7 @@ from loguru import logger
from opentelemetry import trace
from api.db import db_client
from api.services.configuration.registry import ServiceProviders
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
from api.services.gen_ai import build_embedding_service
from api.services.pipecat.tracing_config import ensure_tracing
@ -266,33 +265,19 @@ async def _perform_retrieval(
"Model Configurations > Embedding."
)
if (
embeddings_provider == ServiceProviders.AZURE.value
and embeddings_endpoint
):
embedding_service = AzureOpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
endpoint=embeddings_endpoint,
model_id=embeddings_model or "text-embedding-3-small",
api_version=embeddings_api_version or "2024-02-15-preview",
)
else:
default_headers = None
if (
embeddings_provider == ServiceProviders.DOGRAH.value
and correlation_id
):
default_headers = {
"X-Dograh-Correlation-Id": correlation_id,
}
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
default_headers=default_headers,
)
# Search runs inside a workflow run: reuse the run's MPS correlation
# id (present only for v2 orgs; None otherwise → sent without the
# protocol). The Dograh-managed path forwards it via request metadata.
embedding_service = await build_embedding_service(
db_client=db_client,
provider=embeddings_provider,
api_key=embeddings_api_key,
model=embeddings_model,
base_url=embeddings_base_url,
endpoint=embeddings_endpoint,
api_version=embeddings_api_version,
correlation_id=correlation_id,
)
results = await embedding_service.search_similar_chunks(
query=query,