mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: use mps generated correlation ID
This commit is contained in:
parent
91ac460799
commit
3336c6e794
30 changed files with 453 additions and 89 deletions
|
|
@ -19,7 +19,7 @@ from api.db.models import (
|
|||
WorkflowRunModel,
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class OrganizationUsageClient(BaseDBClient):
|
||||
|
|
@ -473,11 +473,11 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
)
|
||||
config_obj = config_result.scalar_one_or_none()
|
||||
if config_obj and config_obj.configuration:
|
||||
user_config = EffectiveAIModelConfiguration.model_validate(
|
||||
effective_config = EffectiveAIModelConfiguration.model_validate(
|
||||
config_obj.configuration
|
||||
)
|
||||
if user_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = user_config.timezone
|
||||
if effective_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = effective_config.timezone
|
||||
|
||||
# Validate timezone string
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import UserConfigurationModel, UserModel
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class UserClient(BaseDBClient):
|
||||
|
|
|
|||
|
|
@ -384,7 +384,7 @@ async def search_chunks(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
|
|
@ -392,17 +392,17 @@ async def search_chunks(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
if effective_config.embeddings:
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
|
||||
# Initialize embedding service based on provider
|
||||
|
|
|
|||
|
|
@ -1053,13 +1053,15 @@ async def update_workflow(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
effective_config,
|
||||
)
|
||||
effective = resolve_effective_config(
|
||||
effective_config, enriched_overrides
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
DograhEmbeddingsConfiguration,
|
||||
DograhLLMService,
|
||||
|
|
@ -23,6 +23,29 @@ DOGRAH_DEFAULT_VOICE = "default"
|
|||
DOGRAH_DEFAULT_LANGUAGE = "multi"
|
||||
|
||||
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
managed_service_version: int | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
||||
|
||||
class DograhManagedAIModelConfiguration(BaseModel):
|
||||
api_key: str
|
||||
voice: str = DOGRAH_DEFAULT_VOICE
|
||||
|
|
@ -160,6 +183,7 @@ def _compile_dograh_configuration(
|
|||
model="default",
|
||||
),
|
||||
is_realtime=False,
|
||||
managed_service_version=2,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
||||
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
|
@ -9,7 +9,7 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.posthog_client import capture_event
|
||||
|
|
@ -285,8 +285,8 @@ async def create_user_configuration_with_mps_key(
|
|||
"model": "default",
|
||||
},
|
||||
}
|
||||
user_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return user_config
|
||||
effective_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return effective_config
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to get MPS service key: {response.status_code} - {response.text}"
|
||||
|
|
|
|||
|
|
@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import (
|
|||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from groq import Groq
|
|||
# from pyneuphonic import Neuphonic
|
||||
# except ImportError:
|
||||
# Neuphonic = None
|
||||
from api.schemas.user_configuration import (
|
||||
from api.schemas.ai_model_configuration import (
|
||||
EffectiveAIModelConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ The rules are simple:
|
|||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
from api.services.integrations import get_node_secret_fields
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ stored, while honouring masked API keys.
|
|||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
MODEL_OVERRIDE_FIELDS,
|
||||
SERVICE_SECRET_FIELDS,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import copy
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceType,
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
base_url: Optional[str] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
|
|
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
field_name="base_url",
|
||||
)
|
||||
client_kwargs["base_url"] = base_url
|
||||
if default_headers:
|
||||
client_kwargs["default_headers"] = default_headers
|
||||
self.client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
|
|
|
|||
98
api/services/managed_model_services.py
Normal file
98
api/services/managed_model_services.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
|
||||
|
||||
|
||||
def uses_managed_model_services_v2(
|
||||
ai_model_config: EffectiveAIModelConfiguration | None,
|
||||
) -> bool:
|
||||
if (
|
||||
ai_model_config is None
|
||||
or getattr(ai_model_config, "managed_service_version", None) != 2
|
||||
):
|
||||
return False
|
||||
|
||||
return any(
|
||||
_is_dograh_service(getattr(ai_model_config, section_name, None))
|
||||
for section_name in ("llm", "tts", "stt", "embeddings")
|
||||
)
|
||||
|
||||
|
||||
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
|
||||
if not initial_context:
|
||||
return None
|
||||
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
|
||||
if correlation_id is None:
|
||||
return None
|
||||
return str(correlation_id)
|
||||
|
||||
|
||||
async def ensure_mps_correlation_id(
|
||||
*,
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
workflow_run_id: int,
|
||||
initial_context: dict[str, Any] | None,
|
||||
) -> str | None:
|
||||
existing = get_mps_correlation_id(initial_context)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if not uses_managed_model_services_v2(ai_model_config):
|
||||
return None
|
||||
|
||||
service_key = _get_dograh_service_api_key(ai_model_config)
|
||||
if not service_key:
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires a Dograh service key before the run starts."
|
||||
)
|
||||
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
correlation_id = response.get("correlation_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("MPS correlation-id response did not include correlation_id")
|
||||
|
||||
correlation_id = str(correlation_id)
|
||||
logger.info(
|
||||
"Minted MPS correlation id {} for workflow run {}",
|
||||
correlation_id,
|
||||
workflow_run_id,
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
|
||||
def _is_dograh_service(service: Any) -> bool:
|
||||
provider = getattr(service, "provider", None)
|
||||
return (
|
||||
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
|
||||
)
|
||||
|
||||
|
||||
def _get_dograh_service_api_key(
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
) -> str | None:
|
||||
for section_name in ("llm", "tts", "stt", "embeddings"):
|
||||
service = getattr(ai_model_config, section_name, None)
|
||||
if not _is_dograh_service(service):
|
||||
continue
|
||||
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if keys:
|
||||
return keys[0]
|
||||
|
||||
api_key = getattr(service, "api_key", None)
|
||||
if isinstance(api_key, str) and api_key:
|
||||
return api_key
|
||||
|
||||
return None
|
||||
|
|
@ -353,6 +353,40 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
service_key: str,
|
||||
workflow_run_id: int | None = None,
|
||||
) -> dict:
|
||||
"""Mint a server-generated correlation ID for managed model services."""
|
||||
payload: dict[str, int] = {}
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {service_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to create correlation ID: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to create correlation ID: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def transcribe_audio(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
|
|
|
|||
|
|
@ -340,7 +340,7 @@ async def _run_pipeline(
|
|||
if workflow_run.is_completed:
|
||||
raise HTTPException(status_code=400, detail="Workflow run already completed")
|
||||
|
||||
merged_call_context_vars = workflow_run.initial_context
|
||||
merged_call_context_vars = dict(workflow_run.initial_context or {})
|
||||
# If there is some extra call_context_vars, fold them in. Persistence
|
||||
# happens once below, after runtime_configuration is also resolved.
|
||||
if call_context_vars:
|
||||
|
|
@ -398,6 +398,19 @@ async def _run_pipeline(
|
|||
else:
|
||||
user_config = resolved_user_config
|
||||
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=merged_call_context_vars,
|
||||
)
|
||||
if mps_correlation_id:
|
||||
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
|
|
@ -409,11 +422,23 @@ async def _run_pipeline(
|
|||
# Realtime services don't implement run_inference, so create a
|
||||
# separate text LLM for variable extraction and other out-of-band
|
||||
# inference calls.
|
||||
inference_llm = create_llm_service(user_config)
|
||||
inference_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
stt = create_stt_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
keyterms=keyterms,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
tts = create_tts_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = None
|
||||
|
||||
# Stamp the providers/models actually resolved for this run onto
|
||||
|
|
@ -695,7 +720,10 @@ async def _run_pipeline(
|
|||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
voicemail_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
|
|
|
|||
|
|
@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
|||
|
||||
|
||||
def create_stt_service(
|
||||
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
|
||||
user_config,
|
||||
audio_config: "AudioConfig",
|
||||
keyterms: list[str] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
):
|
||||
"""Create and return appropriate STT service based on user configuration
|
||||
|
||||
|
|
@ -160,6 +163,7 @@ def create_stt_service(
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
|
|
@ -286,7 +290,9 @@ def create_stt_service(
|
|||
)
|
||||
|
||||
|
||||
def create_tts_service(user_config, audio_config: "AudioConfig"):
|
||||
def create_tts_service(
|
||||
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
|
||||
):
|
||||
"""Create and return appropriate TTS service based on user configuration
|
||||
|
||||
Args:
|
||||
|
|
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
|
|
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
|
|||
model: str,
|
||||
api_key: str | None,
|
||||
*,
|
||||
correlation_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
|
|
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
|
|||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
|
|
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
def create_llm_service(user_config, correlation_id: str | None = None):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
model = user_config.llm.model
|
||||
|
|
@ -880,4 +889,10 @@ def create_llm_service(user_config):
|
|||
elif provider == ServiceProviders.SARVAM.value:
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
||||
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
|
||||
return create_llm_service_from_provider(
|
||||
provider,
|
||||
model,
|
||||
api_key,
|
||||
correlation_id=correlation_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import asyncio
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
|
|
@ -382,6 +383,9 @@ class PipecatEngine:
|
|||
embeddings_provider=self._embeddings_provider,
|
||||
embeddings_endpoint=self._embeddings_endpoint,
|
||||
embeddings_api_version=self._embeddings_api_version,
|
||||
correlation_id=self._call_context_vars.get(
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
),
|
||||
tracing_context=self._get_otel_context(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn(
|
|||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
||||
llm = create_llm_service(user_config)
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
base_initial_context = dict(workflow_run.initial_context or {})
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=base_initial_context,
|
||||
)
|
||||
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = llm
|
||||
|
||||
runtime_configuration = {
|
||||
|
|
@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn(
|
|||
"llm_model": user_config.llm.model,
|
||||
}
|
||||
initial_context = {
|
||||
**(workflow_run.initial_context or {}),
|
||||
**base_initial_context,
|
||||
"runtime_configuration": runtime_configuration,
|
||||
}
|
||||
if mps_correlation_id:
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
tracing_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
|
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
|
|
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
|
|
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
|
|
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -220,6 +225,7 @@ async def _perform_retrieval(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
|
|
@ -272,11 +278,20 @@ async def _perform_retrieval(
|
|||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
default_headers = None
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.DOGRAH.value
|
||||
and correlation_id
|
||||
):
|
||||
default_headers = {
|
||||
"X-Dograh-Correlation-Id": correlation_id,
|
||||
}
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
|
|
@ -166,18 +166,22 @@ async def process_knowledge_base_document(
|
|||
user_id=document.created_by,
|
||||
organization_id=document.organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
if user_config.embeddings:
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
effective_config = resolved_config.effective
|
||||
if effective_config.embeddings:
|
||||
embeddings_provider = getattr(
|
||||
effective_config.embeddings, "provider", None
|
||||
)
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(
|
||||
effective_config.embeddings, "endpoint", None
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
logger.info(
|
||||
f"Using user embeddings config: provider={embeddings_provider}, "
|
||||
|
|
|
|||
|
|
@ -203,7 +203,7 @@ async def create_workflow_run_rows(
|
|||
Returns:
|
||||
Tuple of (workflow_run, user, workflow).
|
||||
"""
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
|
||||
async_session.add(org)
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@ from pydantic import ValidationError
|
|||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
|
|
@ -49,6 +49,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
|||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
assert effective.managed_service_version == 2
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
|
||||
|
|
|
|||
110
api/tests/test_dograh_managed_correlation.py
Normal file
110
api/tests/test_dograh_managed_correlation.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
from pipecat.frames.frames import TTSStartedFrame
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from websockets.protocol import State
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
def __init__(self):
|
||||
self.state = State.OPEN
|
||||
self.messages: list[dict] = []
|
||||
|
||||
async def send(self, message: str) -> None:
|
||||
self.messages.append(json.loads(message))
|
||||
|
||||
async def close(self, *args, **kwargs) -> None:
|
||||
self.state = State.CLOSED
|
||||
|
||||
|
||||
def test_dograh_llm_uses_explicit_mps_correlation_id():
|
||||
service = DograhLLMService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
settings=OpenAILLMSettings(model="default"),
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
params = service.build_chat_completion_params(
|
||||
{
|
||||
"messages": [],
|
||||
"tools": OPENAI_NOT_GIVEN,
|
||||
"tool_choice": OPENAI_NOT_GIVEN,
|
||||
}
|
||||
)
|
||||
|
||||
assert params["metadata"]["correlation_id"] == "mps-corr-123"
|
||||
assert params["metadata"]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.stt.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhSTTService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=16000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.tts.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhTTSService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=24000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
async def _noop(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service.audio_context_available = lambda context_id: False
|
||||
service.create_audio_context = _noop
|
||||
service.start_ttfb_metrics = _noop
|
||||
service.start_tts_usage_metrics = _noop
|
||||
|
||||
frames = []
|
||||
async for frame in service.run_tts("hello", "ctx-1"):
|
||||
frames.append(frame)
|
||||
|
||||
assert isinstance(frames[0], TTSStartedFrame)
|
||||
assert fake_ws.messages[1]["type"] == "create_context"
|
||||
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[1]["mps_billing_version"] == "2"
|
||||
|
|
@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
|||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.xai.realtime import events
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
|
|
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_grok_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=GrokRealtimeLLMConfiguration(
|
||||
provider="grok_realtime",
|
||||
|
|
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.user import router
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
|
|
|
|||
|
|
@ -87,3 +87,44 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"correlation_id": "mps-corr-123"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.create_correlation_id(
|
||||
service_key="mps_sk_paid",
|
||||
workflow_run_id=42,
|
||||
) == {"correlation_id": "mps-corr-123"}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
{"workflow_run_id": 42},
|
||||
{
|
||||
"Authorization": "Bearer mps_sk_paid",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
|
|||
|
||||
import pytest
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
contains_masked_key,
|
||||
mask_workflow_configurations,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
|
|||
from websockets.exceptions import ConnectionClosedError
|
||||
from websockets.frames import Close
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.ultravox_realtime import (
|
||||
_RESUMPTION_USER_MESSAGE,
|
||||
|
|
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_ultravox_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=UltravoxRealtimeLLMConfiguration(
|
||||
provider="ultravox_realtime",
|
||||
|
|
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
|
||||
from api.db.models import OrganizationModel, UserModel
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue