mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
137 lines
4.8 KiB
Python
137 lines
4.8 KiB
Python
"""Factory for embedding services, including the Dograh-managed (MPS) path.
|
|
|
|
Centralizes the provider branching (Azure BYOK / Dograh-managed / OpenAI-compatible
|
|
BYOK) that was previously duplicated across document ingestion, the search route,
|
|
and the RAG tool, and resolves the MPS billing v2 protocol the same way the voice
|
|
path does: attach it only for orgs already on v2, and never create a billing
|
|
account to do so.
|
|
"""
|
|
|
|
from typing import Optional
|
|
|
|
from loguru import logger
|
|
|
|
from api.db.db_client import DBClient
|
|
|
|
from .azure_openai_service import AzureOpenAIEmbeddingService
|
|
from .base import BaseEmbeddingService
|
|
from .dograh_service import DograhEmbeddingService
|
|
from .openai_service import OpenAIEmbeddingService
|
|
|
|
# Embedding model used by build_embedding_service when the caller passes no
# (or an empty) ``model``.
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
|
# Azure OpenAI API version used by build_embedding_service when the caller
# passes no (or an empty) ``api_version`` for the Azure provider.
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
|
|
|
|
|
|
async def resolve_embedding_correlation_id(
    *,
    organization_id: Optional[int],
    service_key: Optional[str],
    created_by: Optional[str] = None,
) -> Optional[str]:
    """Obtain an MPS correlation id for a managed embedding call outside a run.

    The gating follows the voice path:

    - OSS deployments: the pasted hosted key is v2 by definition, so an id is
      minted straight away through the bearer endpoint (the same approach as
      ``_authorize_oss_managed_v2_correlation``).
    - Hosted/SaaS deployments: the org's billing mode is read first (a lookup
      without side effects) and an id is minted only if the org is already on
      v2. Minting for an already-v2 org is a no-op on the account.

    A v2 billing account is never created here.

    Args:
        organization_id: Org whose billing mode is checked in hosted mode; it
            is not consulted in OSS mode.
        service_key: MPS service key used to mint the correlation id.
        created_by: Forwarded to the billing-status lookup.

    Returns:
        The minted correlation id, or ``None`` when the call should be sent
        without the protocol: no service key, no organization in hosted mode,
        an org that is not on v2, or any error raised while resolving. MPS
        accepts un-gated embedding calls from v1 orgs.
    """
    if not service_key:
        return None

    # Deferred imports keep the gen_ai and service layers free of import-time
    # cycles (same inline-import convention as the rest of the app).
    from api.constants import DEPLOYMENT_MODE
    from api.services.mps_service_key_client import mps_service_key_client

    try:
        if DEPLOYMENT_MODE != "oss":
            # Hosted/SaaS: only orgs that are already on v2 get a correlation id.
            if organization_id is None:
                return None

            account = await mps_service_key_client.get_billing_account_status(
                organization_id, created_by=created_by
            )
            on_v2 = bool(account) and account.get("billing_mode") == "v2"
            if not on_v2:
                return None

        # Reached for OSS (always) and for hosted orgs confirmed to be on v2.
        response = await mps_service_key_client.create_correlation_id(
            service_key=service_key
        )
        return response.get("correlation_id")
    except Exception as exc:
        # Best effort: a failed resolution must not block the embedding call.
        logger.warning(
            "Could not resolve MPS correlation id for managed embeddings; "
            "sending without v2 protocol: {}",
            exc,
        )
        return None
|
|
|
|
|
|
async def build_embedding_service(
    *,
    db_client: DBClient,
    provider: Optional[str],
    api_key: Optional[str],
    model: Optional[str],
    base_url: Optional[str] = None,
    endpoint: Optional[str] = None,
    api_version: Optional[str] = None,
    correlation_id: Optional[str] = None,
    organization_id: Optional[int] = None,
    created_by: Optional[str] = None,
    resolve_correlation: bool = False,
) -> BaseEmbeddingService:
    """Build the embedding service that matches a provider and its settings.

    Selection order:

    1. Azure provider with a non-empty ``endpoint`` gives
       ``AzureOpenAIEmbeddingService``.
    2. Dograh provider gives ``DograhEmbeddingService``.
    3. Anything else gives ``OpenAIEmbeddingService``. This includes the Azure
       provider when no endpoint is given.

    Args:
        db_client: Database client handed to the constructed service.
        provider: Provider identifier, compared against ``ServiceProviders``
            values.
        api_key: Credential handed to the service. For the Dograh provider it
            is also used as the service key when resolving a correlation id.
        model: Embedding model id; ``DEFAULT_EMBEDDING_MODEL`` when empty.
        base_url: Base URL passed to the Dograh and OpenAI-compatible services.
        endpoint: Azure endpoint; required for the Azure service to be chosen.
        api_version: Azure API version; ``DEFAULT_AZURE_API_VERSION`` when
            empty.
        correlation_id: A correlation id already available in context (e.g.
            the running workflow's MPS correlation id). Only used for the
            Dograh provider.
        organization_id: Passed to the correlation-id resolution.
        created_by: Passed to the correlation-id resolution.
        resolve_correlation: When True and ``correlation_id`` is ``None``,
            resolve one for the Dograh provider via
            ``resolve_embedding_correlation_id`` (for calls made outside a
            workflow run: ingestion, manual search).

    Returns:
        The constructed embedding service.
    """
    from api.services.configuration.registry import ServiceProviders

    # Arguments every service constructor receives.
    shared_kwargs = {
        "db_client": db_client,
        "api_key": api_key,
        "model_id": model if model else DEFAULT_EMBEDDING_MODEL,
    }

    if provider == ServiceProviders.AZURE.value and endpoint:
        return AzureOpenAIEmbeddingService(
            endpoint=endpoint,
            api_version=api_version if api_version else DEFAULT_AZURE_API_VERSION,
            **shared_kwargs,
        )

    if provider != ServiceProviders.DOGRAH.value:
        # OpenAI-compatible BYOK is the default for every other provider value.
        return OpenAIEmbeddingService(base_url=base_url, **shared_kwargs)

    # Dograh-managed path: prefer the id from context, resolve one on request.
    effective_cid = correlation_id
    if resolve_correlation and effective_cid is None:
        effective_cid = await resolve_embedding_correlation_id(
            organization_id=organization_id,
            service_key=api_key,
            created_by=created_by,
        )

    return DograhEmbeddingService(
        base_url=base_url,
        correlation_id=effective_cid,
        **shared_kwargs,
    )
|