mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from api.services.configuration.registry import ServiceProviders
|
||
|
|
from api.services.integrations.base import (
|
||
|
|
IntegrationRuntimeContext,
|
||
|
|
IntegrationRuntimeSession,
|
||
|
|
)
|
||
|
|
|
||
|
|
from .collector import TunerCollector, mode_to_tuner_call_type
|
||
|
|
|
||
|
|
|
||
|
|
def _format_model_label(provider: str | None, model: str | None) -> str:
    """Combine a provider and a model name into a single display label.

    Returns ``"provider/model"`` when both values are truthy, the model alone
    when only the model is truthy, and otherwise the provider (or ``""`` when
    the provider is falsy as well).
    """
    if not model:
        # Without a model the provider is the only thing left to show.
        return provider or ""
    return f"{provider}/{model}" if provider else model
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_model_labels(context: IntegrationRuntimeContext) -> tuple[str, str, str]:
    """Return the ``(asr_model, llm_model, tts_model)`` labels for a run.

    When ``context.is_realtime`` is truthy and ``user_config.realtime`` is set,
    only the LLM label is populated (from the realtime provider and model);
    the ASR and TTS labels are empty strings. Otherwise each label is built
    from the ``provider`` and ``model`` attributes of ``user_config.stt``,
    ``user_config.llm`` and ``user_config.tts`` respectively, with a missing
    attribute (or a ``None`` section) treated as ``None``.
    """
    user_config = context.user_config

    if context.is_realtime and user_config.realtime:
        realtime = user_config.realtime
        # The same triple is returned for every realtime provider, so no
        # per-provider branching is needed here.
        # NOTE(review): ASR/TTS are presumably left blank because the realtime
        # model handles speech itself -- confirm with the collector's consumer.
        return "", _format_model_label(realtime.provider, realtime.model), ""

    def _label_for(section: Any) -> str:
        # getattr with a default tolerates sections that are None or that
        # lack one of the two attributes.
        return _format_model_label(
            getattr(section, "provider", None),
            getattr(section, "model", None),
        )

    return (
        _label_for(user_config.stt),
        _label_for(user_config.llm),
        _label_for(user_config.tts),
    )
|
||
|
|
|
||
|
|
|
||
|
|
class TunerRuntimeSession(IntegrationRuntimeSession):
    """Integration runtime session backed by a ``TunerCollector``."""

    name = "tuner"

    def __init__(self, collector: TunerCollector) -> None:
        """Keep a reference to the collector this session drives."""
        self._collector = collector

    def attach(self, task: Any) -> None:
        """Connect the collector to *task*.

        Passes the task's turn-tracking observer and user/bot latency observer
        to the collector, then registers the collector itself as an observer
        on the task.
        """
        collector = self._collector
        collector.attach_turn_tracking_observer(task.turn_tracking_observer)
        collector.attach_latency_observer(task.user_bot_latency_observer)
        task.add_observer(collector)

    async def on_call_finished(
        self,
        *,
        gathered_context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Record the call disposition and return the collector's payload.

        The ``"call_disposition"`` entry of *gathered_context* (``None`` when
        absent) is stored as the disconnection reason before the payload
        snapshot is built. Returns ``{"tuner_payload": payload}``, or ``None``
        when the collector produces no payload.
        """
        disposition = gathered_context.get("call_disposition")
        self._collector.set_disconnection_reason(disposition)

        snapshot = self._collector.build_payload_snapshot()
        return None if snapshot is None else {"tuner_payload": snapshot}
|
||
|
|
|
||
|
|
|
||
|
|
def create_runtime_sessions(
    context: IntegrationRuntimeContext,
) -> list[IntegrationRuntimeSession]:
    """Create the tuner runtime sessions for the run described by *context*.

    Returns an empty list unless the workflow graph contains at least one node
    whose ``node_type`` is ``"tuner"`` and whose ``data.tuner_enabled`` is
    truthy (a missing ``tuner_enabled`` attribute counts as enabled).
    Otherwise returns a single ``TunerRuntimeSession`` wrapping a
    ``TunerCollector`` configured from the run's id, mode, resolved model
    labels and agent version, with the context-messages provider attached.
    """
    # Only the existence of an enabled tuner node matters, so stop at the
    # first match instead of collecting every matching node.
    has_enabled_tuner = any(
        node.node_type == "tuner" and getattr(node.data, "tuner_enabled", True)
        for node in context.workflow_graph.nodes.values()
    )
    if not has_enabled_tuner:
        return []

    asr_model, llm_model, tts_model = _resolve_model_labels(context)

    collector = TunerCollector(
        workflow_run_id=context.workflow_run_id,
        call_type=mode_to_tuner_call_type(context.workflow_run.mode),
        asr_model=asr_model,
        llm_model=llm_model,
        tts_model=tts_model,
        # The run definition may not carry a version number; fall back to None.
        agent_version=getattr(context.run_definition, "version_number", None),
    )
    collector.attach_context(context.context_messages_provider)

    return [TunerRuntimeSession(collector)]
|