mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
98 lines
2.9 KiB
Python
98 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
|
from api.services.configuration.registry import ServiceProviders
|
|
from api.services.mps_service_key_client import mps_service_key_client
|
|
|
|
# Dictionary key under which an MPS correlation id is looked up in an
# ``initial_context`` mapping (read by ``get_mps_correlation_id``).
MPS_CORRELATION_ID_CONTEXT_KEY: str = "mps_correlation_id"
|
|
|
|
|
|
def uses_managed_model_services_v2(
    ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
    """Report whether a model configuration runs on managed model services v2.

    Args:
        ai_model_config: The effective configuration to inspect, or ``None``.

    Returns:
        ``True`` only when the config's ``managed_service_version`` attribute
        equals ``2`` AND at least one of its ``llm``, ``tts``, ``stt`` or
        ``embeddings`` sections has a Dograh provider (per
        ``_is_dograh_service``). ``False`` otherwise, including for ``None``
        or a config lacking ``managed_service_version``.
    """
    if ai_model_config is None:
        return False

    # A missing attribute reads as None, which is never equal to 2.
    if getattr(ai_model_config, "managed_service_version", None) != 2:
        return False

    # Lazily fetch each section so ``any`` can stop at the first Dograh one.
    candidate_sections = (
        getattr(ai_model_config, name, None)
        for name in ("llm", "tts", "stt", "embeddings")
    )
    return any(_is_dograh_service(section) for section in candidate_sections)
|
|
|
|
|
|
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
    """Read a previously stored MPS correlation id out of ``initial_context``.

    Args:
        initial_context: Mapping that may hold the id under
            ``MPS_CORRELATION_ID_CONTEXT_KEY``; ``None`` or empty is allowed.

    Returns:
        The stored value converted with ``str()``, or ``None`` when the
        context is ``None``/empty or the key is absent or maps to ``None``.
    """
    # Guard first: nothing to look up in a missing or empty context.
    if not initial_context:
        return None

    raw_value = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
    return None if raw_value is None else str(raw_value)
|
|
|
|
|
|
async def ensure_mps_correlation_id(
    *,
    ai_model_config: EffectiveAIModelConfiguration,
    workflow_run_id: int,
    initial_context: dict[str, Any] | None,
) -> str | None:
    """Return the MPS correlation id for a workflow run, minting one if needed.

    Resolution order:
      1. A truthy id already present in ``initial_context`` is returned as-is.
      2. If the config does not use managed model services v2, ``None`` is
         returned and no remote call is made.
      3. Otherwise a new id is requested from ``mps_service_key_client`` using
         the Dograh service key found in the config, and is logged.

    This function does not write the minted id back into ``initial_context``.
    NOTE(review): callers presumably persist it themselves — confirm.

    Args:
        ai_model_config: Effective model configuration for the run.
        workflow_run_id: Identifier of the run the id is minted for.
        initial_context: Run context that may already carry an id.

    Returns:
        The correlation id as a string, or ``None`` when v2 is not in use.

    Raises:
        ValueError: If no Dograh service key is configured, or if the client
            response lacks a truthy ``correlation_id``. Errors raised by the
            client call itself propagate unchanged.
    """
    already_minted = get_mps_correlation_id(initial_context)
    if already_minted:
        return already_minted

    if not uses_managed_model_services_v2(ai_model_config):
        return None

    dograh_key = _get_dograh_service_api_key(ai_model_config)
    if not dograh_key:
        raise ValueError(
            "Managed model services v2 requires a Dograh service key before the run starts."
        )

    payload = await mps_service_key_client.create_correlation_id(
        service_key=dograh_key,
        workflow_run_id=workflow_run_id,
    )
    raw_id = payload.get("correlation_id")
    if not raw_id:
        raise ValueError("MPS correlation-id response did not include correlation_id")

    minted_id = str(raw_id)
    logger.info(
        "Minted MPS correlation id {} for workflow run {}",
        minted_id,
        workflow_run_id,
    )
    return minted_id
|
|
|
|
|
|
def _is_dograh_service(service: Any) -> bool:
    """Tell whether ``service`` declares Dograh as its provider.

    The ``provider`` attribute is compared against both the
    ``ServiceProviders.DOGRAH`` member and its ``.value``, so either form
    matches. Objects without a ``provider`` attribute (including ``None``)
    never match.
    """
    provider = getattr(service, "provider", None)
    dograh_member = ServiceProviders.DOGRAH

    matches_member = provider == dograh_member
    return matches_member or provider == dograh_member.value
|
|
|
|
|
|
def _get_dograh_service_api_key(
    ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
    """Return the first usable API key from any Dograh-provided service section.

    Sections are scanned in the order ``llm``, ``tts``, ``stt``,
    ``embeddings``; sections whose provider is not Dograh (per
    ``_is_dograh_service``) are skipped. For each Dograh section, keys from
    ``get_all_api_keys()`` (when the service exposes that method) are tried
    first, then the ``api_key`` attribute (accepted only if it is a
    non-empty ``str``).

    Empty or ``None`` entries from ``get_all_api_keys()`` are skipped rather
    than returned, so a blank entry can no longer hide a valid key that
    appears later in the list, in ``api_key``, or in another section.

    Args:
        ai_model_config: Effective model configuration to search.

    Returns:
        The first truthy key found, or ``None`` if no Dograh section has one.
    """
    for section_name in ("llm", "tts", "stt", "embeddings"):
        service = getattr(ai_model_config, section_name, None)
        if not _is_dograh_service(service):
            continue

        if hasattr(service, "get_all_api_keys"):
            # Previously ``keys[0]`` was returned without validation, so a
            # falsy first entry ("" / None) was handed to the caller, which
            # then reported a missing key. Take the first truthy one instead;
            # ``or ()`` keeps a ``None`` result falling through as before.
            first_usable = next(
                (key for key in service.get_all_api_keys() or () if key),
                None,
            )
            if first_usable:
                return first_usable

        api_key = getattr(service, "api_key", None)
        if isinstance(api_key, str) and api_key:
            return api_key

    return None
|