mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
Merge remote-tracking branch 'origin/main' into feat/user-onboarding
This commit is contained in:
commit
093e888ce4
148 changed files with 10908 additions and 2815 deletions
|
|
@ -35,6 +35,7 @@ import asyncio
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
|
|
@ -382,6 +383,9 @@ class PipecatEngine:
|
|||
embeddings_provider=self._embeddings_provider,
|
||||
embeddings_endpoint=self._embeddings_endpoint,
|
||||
embeddings_api_version=self._embeddings_api_version,
|
||||
correlation_id=self._call_context_vars.get(
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
),
|
||||
tracing_context=self._get_otel_context(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.workflow.dto import QANodeData
|
||||
|
||||
|
|
@ -43,7 +42,7 @@ async def resolve_llm_config(
|
|||
async def resolve_user_llm_config(
|
||||
workflow_run: WorkflowRunModel,
|
||||
) -> tuple[str, str, str, dict]:
|
||||
"""Resolve the user's configured LLM (from UserConfiguration).
|
||||
"""Resolve the user's configured LLM (from EffectiveAIModelConfiguration).
|
||||
|
||||
Returns:
|
||||
(provider, model, api_key, service_kwargs) tuple
|
||||
|
|
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
|
|||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_configurations = {}
|
||||
if workflow_run.definition:
|
||||
workflow_configurations = (
|
||||
workflow_run.definition.workflow_configurations or {}
|
||||
)
|
||||
elif workflow_run.workflow:
|
||||
workflow_configurations = (
|
||||
workflow_run.workflow.workflow_configurations or {}
|
||||
)
|
||||
|
||||
user_configuration = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow_run.workflow.organization_id
|
||||
if workflow_run.workflow
|
||||
else None,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
|
|
|
|||
41
api/services/workflow/run_usage_response.py
Normal file
41
api/services/workflow/run_usage_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
||||
|
||||
def format_public_cost_info(
|
||||
cost_info: dict | None, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
"""Return the legacy response shape without doing local cost accounting."""
|
||||
duration = None
|
||||
if usage_info and usage_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(usage_info.get("call_duration_seconds") or 0))
|
||||
elif cost_info and cost_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(cost_info.get("call_duration_seconds") or 0))
|
||||
|
||||
dograh_token_usage = 0
|
||||
if cost_info:
|
||||
if "dograh_token_usage" in cost_info:
|
||||
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
|
||||
elif "total_cost_usd" in cost_info:
|
||||
dograh_token_usage = round(
|
||||
float(cost_info.get("total_cost_usd", 0)) * 100, 2
|
||||
)
|
||||
|
||||
if duration is None and dograh_token_usage == 0:
|
||||
return None
|
||||
|
||||
return {
|
||||
"dograh_token_usage": dograh_token_usage,
|
||||
"call_duration_seconds": duration,
|
||||
}
|
||||
|
|
@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.pipeline_builder import create_pipeline_task
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import (
|
||||
|
|
@ -410,14 +409,31 @@ async def execute_text_chat_pending_turn(
|
|||
run_definition = workflow_run.definition
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_run.workflow.user.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
||||
llm = create_llm_service(user_config)
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
base_initial_context = dict(workflow_run.initial_context or {})
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=base_initial_context,
|
||||
)
|
||||
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = llm
|
||||
|
||||
runtime_configuration = {
|
||||
|
|
@ -425,9 +441,15 @@ async def execute_text_chat_pending_turn(
|
|||
"llm_model": user_config.llm.model,
|
||||
}
|
||||
initial_context = {
|
||||
**(workflow_run.initial_context or {}),
|
||||
**base_initial_context,
|
||||
"runtime_configuration": runtime_configuration,
|
||||
}
|
||||
if mps_correlation_id:
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
|
|
@ -466,9 +488,17 @@ async def execute_text_chat_pending_turn(
|
|||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
if user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
|
||||
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
|
||||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
|
|
@ -606,8 +636,10 @@ async def execute_text_chat_pending_turn(
|
|||
"Transportless text chat pipeline failed while closing run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
raise
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@ from datetime import UTC, datetime
|
|||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
)
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
|
|
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
|
|||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
if workflow_run:
|
||||
try:
|
||||
# Apply the per-turn delta so org usage tracks cumulative run cost
|
||||
# without replaying the full session totals on every turn.
|
||||
await apply_usage_delta_to_organization(workflow_run, execution.usage)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for text chat run {run_id}: {e}"
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is not None:
|
||||
await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
||||
|
||||
return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
tracing_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
|
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
|
|
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
|
|
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
|
|
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -220,6 +225,7 @@ async def _perform_retrieval(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
|
|
@ -272,11 +278,20 @@ async def _perform_retrieval(
|
|||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
default_headers = None
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.DOGRAH.value
|
||||
and correlation_id
|
||||
):
|
||||
default_headers = {
|
||||
"X-Dograh-Correlation-Id": correlation_id,
|
||||
}
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue