mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
chore: fix tracing for text chat mode
This commit is contained in:
parent
e23cce444f
commit
08a2435ba5
31 changed files with 1753 additions and 597 deletions
|
|
@ -534,7 +534,7 @@ class PipecatEngine:
|
|||
)
|
||||
await self._update_llm_context(system_prompt, functions)
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
async def set_node(self, node_id: str, emit_transition_event: bool = True):
|
||||
"""
|
||||
Simplified set_node implementation according to v2 PRD.
|
||||
"""
|
||||
|
|
@ -557,7 +557,7 @@ class PipecatEngine:
|
|||
nodes_visited.append(node.name)
|
||||
|
||||
# Send node transition event if callback is provided
|
||||
if self._node_transition_callback:
|
||||
if emit_transition_event and self._node_transition_callback:
|
||||
try:
|
||||
await self._node_transition_callback(
|
||||
node_id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ import re
|
|||
from loguru import logger
|
||||
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.pipecat.tracing_config import get_trace_url
|
||||
from api.services.pipecat.tracing_config import (
|
||||
build_remote_parent_context,
|
||||
get_trace_url,
|
||||
)
|
||||
|
||||
|
||||
def extract_trace_id(gathered_context: dict) -> str | None:
|
||||
|
|
@ -33,36 +36,12 @@ def setup_langfuse_parent_context(workflow_run: WorkflowRunModel):
|
|||
|
||||
Returns the parent context object, or None if tracing is unavailable.
|
||||
"""
|
||||
try:
|
||||
from opentelemetry.trace import (
|
||||
NonRecordingSpan,
|
||||
SpanContext,
|
||||
TraceFlags,
|
||||
set_span_in_context,
|
||||
)
|
||||
|
||||
from api.services.pipecat.tracing_config import ensure_tracing
|
||||
|
||||
if not ensure_tracing():
|
||||
return None
|
||||
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
trace_id = extract_trace_id(gathered_context)
|
||||
if not trace_id:
|
||||
logger.debug("No trace_id found, skipping Langfuse tracing")
|
||||
return None
|
||||
|
||||
parent_span_ctx = SpanContext(
|
||||
trace_id=int(trace_id, 16),
|
||||
span_id=0x1,
|
||||
is_remote=True,
|
||||
trace_flags=TraceFlags(0x01),
|
||||
)
|
||||
return set_span_in_context(NonRecordingSpan(parent_span_ctx))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set up Langfuse parent context: {e}")
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
trace_id = extract_trace_id(gathered_context)
|
||||
if not trace_id:
|
||||
logger.debug("No trace_id found, skipping Langfuse tracing")
|
||||
return None
|
||||
return build_remote_parent_context(trace_id)
|
||||
|
||||
|
||||
def add_qa_span_to_trace(
|
||||
|
|
|
|||
143
api/services/workflow/text_chat_logs.py
Normal file
143
api/services/workflow/text_chat_logs.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""Helpers for projecting text-chat session state into run-log snapshots."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from api.services.pipecat.realtime_feedback_events import (
|
||||
build_bot_text_event,
|
||||
build_function_call_end_event,
|
||||
build_function_call_start_event,
|
||||
build_node_transition_event,
|
||||
build_pipeline_error_event,
|
||||
build_user_transcription_event,
|
||||
realtime_feedback_event_sort_key,
|
||||
stamp_realtime_feedback_event,
|
||||
)
|
||||
|
||||
|
||||
def visible_text_chat_turns(session_data: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Return the active branch of turns for the current text-chat session.
|
||||
|
||||
After a rewind, `session_data["turns"]` may still contain future turns until
|
||||
the next message is sent. Those turns are no longer part of the visible
|
||||
branch, so callers that synthesize transcript/log views should trim at
|
||||
`cursor_turn_id`.
|
||||
"""
|
||||
turns = list(session_data.get("turns") or [])
|
||||
cursor_turn_id = session_data.get("cursor_turn_id")
|
||||
if cursor_turn_id is None:
|
||||
return turns
|
||||
|
||||
for index, turn in enumerate(turns):
|
||||
if turn.get("id") == cursor_turn_id:
|
||||
return turns[: index + 1]
|
||||
|
||||
return turns
|
||||
|
||||
|
||||
def build_text_chat_realtime_feedback_events(
|
||||
session_data: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Project text-chat session state into `workflow_runs.logs` event format.
|
||||
|
||||
`workflow_run_text_sessions` holds the authoritative rewindable conversation
|
||||
state. Historical run pages and QA helpers read the normalized
|
||||
`workflow_runs.logs.realtime_feedback_events` schema instead, so this helper
|
||||
rebuilds that snapshot from the currently visible branch.
|
||||
"""
|
||||
events: list[dict[str, Any]] = []
|
||||
last_emitted_node_id: str | None = None
|
||||
|
||||
for turn_index, turn in enumerate(visible_text_chat_turns(session_data)):
|
||||
turn_events = list(turn.get("events") or [])
|
||||
for event in turn_events:
|
||||
payload = dict(event.get("payload") or {})
|
||||
event_type = event.get("type")
|
||||
timestamp = event.get("created_at") or turn.get("created_at")
|
||||
|
||||
if event_type == "node_transition":
|
||||
node_id = payload.get("node_id")
|
||||
if node_id is not None and node_id == last_emitted_node_id:
|
||||
continue
|
||||
snapshot_event = stamp_realtime_feedback_event(
|
||||
build_node_transition_event(
|
||||
node_id=node_id,
|
||||
node_name=payload.get("node_name"),
|
||||
previous_node_id=payload.get("previous_node_id"),
|
||||
previous_node_name=payload.get("previous_node_name"),
|
||||
allow_interrupt=bool(payload.get("allow_interrupt", False)),
|
||||
),
|
||||
timestamp=timestamp,
|
||||
turn=turn_index,
|
||||
node_id=node_id,
|
||||
node_name=payload.get("node_name"),
|
||||
)
|
||||
if node_id is not None:
|
||||
last_emitted_node_id = node_id
|
||||
events.append(snapshot_event)
|
||||
elif event_type == "tool_call_started":
|
||||
events.append(
|
||||
stamp_realtime_feedback_event(
|
||||
build_function_call_start_event(
|
||||
function_name=payload.get("function_name"),
|
||||
tool_call_id=payload.get("tool_call_id"),
|
||||
),
|
||||
timestamp=timestamp,
|
||||
turn=turn_index,
|
||||
)
|
||||
)
|
||||
elif event_type == "tool_call_result":
|
||||
events.append(
|
||||
stamp_realtime_feedback_event(
|
||||
build_function_call_end_event(
|
||||
function_name=payload.get("function_name"),
|
||||
tool_call_id=payload.get("tool_call_id"),
|
||||
result=payload.get("result"),
|
||||
),
|
||||
timestamp=timestamp,
|
||||
turn=turn_index,
|
||||
)
|
||||
)
|
||||
elif event_type == "execution_error":
|
||||
events.append(
|
||||
stamp_realtime_feedback_event(
|
||||
build_pipeline_error_event(
|
||||
error=payload.get("message", "Execution error"),
|
||||
fatal=True,
|
||||
),
|
||||
timestamp=timestamp,
|
||||
turn=turn_index,
|
||||
)
|
||||
)
|
||||
|
||||
user_message = turn.get("user_message") or {}
|
||||
if user_message.get("text"):
|
||||
message_timestamp = user_message.get("created_at") or turn.get("created_at")
|
||||
events.append(
|
||||
stamp_realtime_feedback_event(
|
||||
build_user_transcription_event(
|
||||
text=user_message["text"],
|
||||
final=True,
|
||||
timestamp=message_timestamp,
|
||||
),
|
||||
timestamp=message_timestamp,
|
||||
turn=turn_index,
|
||||
)
|
||||
)
|
||||
|
||||
assistant_message = turn.get("assistant_message") or {}
|
||||
if assistant_message.get("text"):
|
||||
message_timestamp = assistant_message.get("created_at") or turn.get(
|
||||
"created_at"
|
||||
)
|
||||
events.append(
|
||||
stamp_realtime_feedback_event(
|
||||
build_bot_text_event(
|
||||
text=assistant_message["text"],
|
||||
timestamp=message_timestamp,
|
||||
),
|
||||
timestamp=message_timestamp,
|
||||
turn=turn_index,
|
||||
)
|
||||
)
|
||||
|
||||
return sorted(events, key=realtime_feedback_event_sort_key)
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -28,6 +29,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.run_context import set_current_org_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
|
|
@ -39,6 +41,10 @@ from api.services.pipecat.pipeline_metrics_aggregator import (
|
|||
)
|
||||
from api.services.pipecat.recording_audio_cache import create_recording_audio_fetcher
|
||||
from api.services.pipecat.service_factory import create_llm_service
|
||||
from api.services.pipecat.tracing_config import (
|
||||
build_remote_parent_context,
|
||||
get_trace_url,
|
||||
)
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
|
@ -49,6 +55,19 @@ TEXT_CHAT_IDLE_SETTLE_SECONDS = 0.2
|
|||
TEXT_CHAT_INTERNAL_CANCEL_REASON = "text_chat_turn_complete"
|
||||
|
||||
|
||||
def text_chat_trace_id(workflow_run_id: int) -> str:
|
||||
"""Deterministic Langfuse trace id for a text-chat session.
|
||||
|
||||
Each turn runs in its own short-lived pipeline, so there is no single
|
||||
long-running task to own the trace the way a voice call does. Deriving the
|
||||
id from the run id means every turn re-creates the *same* trace id and all
|
||||
per-turn spans land in one shared trace — without persisting extra state
|
||||
across the otherwise stateless turn requests.
|
||||
"""
|
||||
digest = hashlib.sha256(f"dograh-text-chat:{workflow_run_id}".encode()).hexdigest()
|
||||
return digest[:32]
|
||||
|
||||
|
||||
def default_text_chat_checkpoint() -> dict[str, Any]:
|
||||
return {
|
||||
"version": TEXT_CHAT_CHECKPOINT_VERSION,
|
||||
|
|
@ -379,6 +398,12 @@ async def execute_text_chat_pending_turn(
|
|||
if workflow is None:
|
||||
raise ValueError("Workflow not found for text chat execution")
|
||||
|
||||
# Stamp the async context so OTEL spans are tagged with this org and routed
|
||||
# to its Langfuse project (the voice paths do this in run_pipeline /
|
||||
# webrtc_signaling; the text path previously skipped it, so its spans never
|
||||
# reached org-specific exporters).
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
||||
run_definition = workflow_run.definition
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
|
|
@ -482,6 +507,17 @@ async def execute_text_chat_pending_turn(
|
|||
audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
|
||||
pipeline_metrics_aggregator = PipelineMetricsAggregator()
|
||||
|
||||
# Stitch every per-turn pipeline of this session into one Langfuse trace by
|
||||
# handing each task the same remote parent context (derived from the run id).
|
||||
trace_id = text_chat_trace_id(workflow_run_id)
|
||||
conversation_parent_context = build_remote_parent_context(trace_id)
|
||||
# The stitched trace has no real root span (each per-turn conversation span
|
||||
# hangs off a synthetic remote parent), so Langfuse can't infer a name and
|
||||
# shows "Unnamed trace". Name it explicitly via the conversation span.
|
||||
trace_span_attributes = {
|
||||
"langfuse.trace.name": workflow_run.name or f"text-chat-{workflow_run_id}"
|
||||
}
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm,
|
||||
|
|
@ -490,7 +526,14 @@ async def execute_text_chat_pending_turn(
|
|||
pipeline_metrics_aggregator,
|
||||
]
|
||||
)
|
||||
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
|
||||
task = create_pipeline_task(
|
||||
pipeline,
|
||||
workflow_run_id,
|
||||
audio_config,
|
||||
conversation_parent_context=conversation_parent_context,
|
||||
conversation_type="text",
|
||||
additional_span_attributes=trace_span_attributes,
|
||||
)
|
||||
runner = PipelineRunner(handle_sigint=False, handle_sigterm=False)
|
||||
runner_task = asyncio.create_task(runner.run(task))
|
||||
|
||||
|
|
@ -511,7 +554,10 @@ async def execute_text_chat_pending_turn(
|
|||
|
||||
current_node_id = base_checkpoint.get("current_node_id")
|
||||
target_node_id = current_node_id or workflow_graph.start_node_id
|
||||
await engine.set_node(target_node_id)
|
||||
await engine.set_node(
|
||||
target_node_id,
|
||||
emit_transition_event=current_node_id is None,
|
||||
)
|
||||
|
||||
opening_marker = capture_processor.activity_count
|
||||
opening_expects_llm = pending_user_message is None and (
|
||||
|
|
@ -581,13 +627,18 @@ async def execute_text_chat_pending_turn(
|
|||
"tool_state": jsonable_encoder(base_checkpoint.get("tool_state") or {}),
|
||||
}
|
||||
|
||||
encoded_gathered_context = jsonable_encoder(gathered_context)
|
||||
trace_url = get_trace_url(trace_id, org_id=workflow.organization_id)
|
||||
if trace_url:
|
||||
encoded_gathered_context = {**encoded_gathered_context, "trace_url": trace_url}
|
||||
|
||||
return TextChatTurnExecutionResult(
|
||||
assistant_text=assistant_text,
|
||||
assistant_created_at=assistant_created_at,
|
||||
events=jsonable_encoder(capture_processor.events),
|
||||
usage=jsonable_encoder(usage),
|
||||
checkpoint=updated_checkpoint,
|
||||
gathered_context=jsonable_encoder(gathered_context),
|
||||
gathered_context=encoded_gathered_context,
|
||||
initial_context=jsonable_encoder(initial_context),
|
||||
state=(
|
||||
WorkflowRunState.COMPLETED.value
|
||||
|
|
|
|||
396
api/services/workflow/text_chat_session_service.py
Normal file
396
api/services/workflow/text_chat_session_service.py
Normal file
|
|
@ -0,0 +1,396 @@
|
|||
"""Service helpers for text-chat session lifecycle orchestration."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import build_workflow_run_cost_info
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
from api.services.workflow.text_chat_runner import (
|
||||
default_text_chat_checkpoint,
|
||||
execute_text_chat_pending_turn,
|
||||
merge_text_chat_usage_info,
|
||||
normalize_text_chat_checkpoint,
|
||||
)
|
||||
|
||||
TEXT_CHAT_SESSION_VERSION = 1
|
||||
|
||||
|
||||
class TextChatSessionRevisionConflictError(Exception):
|
||||
def __init__(self, expected_revision: int, actual_revision: int):
|
||||
self.expected_revision = expected_revision
|
||||
self.actual_revision = actual_revision
|
||||
super().__init__(
|
||||
"Text chat session revision conflict: "
|
||||
f"expected {expected_revision}, found {actual_revision}"
|
||||
)
|
||||
|
||||
|
||||
class TextChatSessionExecutionError(Exception):
|
||||
"""Raised when the assistant turn fails to execute."""
|
||||
|
||||
|
||||
class TextChatPendingTurnLostError(Exception):
|
||||
"""Raised when the pending turn disappears before persistence completes."""
|
||||
|
||||
|
||||
class TextChatTurnNotFoundError(Exception):
|
||||
"""Raised when a requested rewind cursor does not exist in the session."""
|
||||
|
||||
|
||||
def default_text_chat_session_data() -> dict[str, Any]:
|
||||
return {
|
||||
"version": TEXT_CHAT_SESSION_VERSION,
|
||||
"status": "idle",
|
||||
"cursor_turn_id": None,
|
||||
"turns": [],
|
||||
"discarded_future": [],
|
||||
"simulator": {
|
||||
"enabled": False,
|
||||
"config": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def normalize_text_chat_session_data(
|
||||
session_data: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
normalized = {
|
||||
**default_text_chat_session_data(),
|
||||
**(session_data or {}),
|
||||
}
|
||||
normalized["turns"] = list(normalized.get("turns") or [])
|
||||
normalized["discarded_future"] = list(normalized.get("discarded_future") or [])
|
||||
simulator = normalized.get("simulator") or {}
|
||||
normalized["simulator"] = {
|
||||
"enabled": bool(simulator.get("enabled", False)),
|
||||
"config": dict(simulator.get("config") or {}),
|
||||
}
|
||||
return normalized
|
||||
|
||||
|
||||
async def initialize_text_chat_session(
|
||||
*,
|
||||
run_id: int,
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
checkpoint = normalize_text_chat_checkpoint(text_session.checkpoint)
|
||||
|
||||
session_data["turns"] = [build_pending_text_chat_turn(user_text=None)]
|
||||
session_data["status"] = "pending_assistant_turn"
|
||||
checkpoint["anchor_turn_id"] = latest_completed_text_chat_turn_id(
|
||||
session_data["turns"]
|
||||
)
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=session_data,
|
||||
checkpoint=checkpoint,
|
||||
expected_revision=text_session.revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
raise TextChatSessionRevisionConflictError(
|
||||
expected_revision=e.expected_revision,
|
||||
actual_revision=e.actual_revision,
|
||||
) from e
|
||||
|
||||
return await _reload_text_chat_session(run_id, text_session)
|
||||
|
||||
|
||||
async def append_text_chat_user_message(
|
||||
*,
|
||||
run_id: int,
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
user_text: str,
|
||||
expected_revision: int | None,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
checkpoint = normalize_text_chat_checkpoint(text_session.checkpoint)
|
||||
|
||||
active_turns, discarded_future = truncate_text_chat_future_turns(session_data)
|
||||
active_turns.append(build_pending_text_chat_turn(user_text=user_text))
|
||||
|
||||
session_data["turns"] = active_turns
|
||||
session_data["discarded_future"] = discarded_future
|
||||
session_data["cursor_turn_id"] = None
|
||||
session_data["status"] = "pending_assistant_turn"
|
||||
checkpoint["anchor_turn_id"] = latest_completed_text_chat_turn_id(active_turns)
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=session_data,
|
||||
checkpoint=checkpoint,
|
||||
expected_revision=expected_revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
raise TextChatSessionRevisionConflictError(
|
||||
expected_revision=e.expected_revision,
|
||||
actual_revision=e.actual_revision,
|
||||
) from e
|
||||
|
||||
return await _reload_text_chat_session(run_id, text_session)
|
||||
|
||||
|
||||
async def rewind_text_chat_session_state(
|
||||
*,
|
||||
run_id: int,
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
cursor_turn_id: str | None,
|
||||
expected_revision: int | None,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
validate_text_chat_turn_cursor(session_data, cursor_turn_id)
|
||||
|
||||
session_data["cursor_turn_id"] = cursor_turn_id
|
||||
session_data["status"] = "rewound" if cursor_turn_id else "idle"
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=session_data,
|
||||
expected_revision=expected_revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
raise TextChatSessionRevisionConflictError(
|
||||
expected_revision=e.expected_revision,
|
||||
actual_revision=e.actual_revision,
|
||||
) from e
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id,
|
||||
logs={
|
||||
"realtime_feedback_events": build_text_chat_realtime_feedback_events(
|
||||
session_data
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
return await _reload_text_chat_session(run_id, text_session)
|
||||
|
||||
|
||||
async def execute_pending_text_chat_turn(
|
||||
*,
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
"""Execute the current pending assistant turn and persist its side effects."""
|
||||
session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
checkpoint = normalize_text_chat_checkpoint(text_session.checkpoint)
|
||||
|
||||
try:
|
||||
execution = await execute_text_chat_pending_turn(
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=workflow_id,
|
||||
session_data=session_data,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
except Exception as e:
|
||||
await _mark_pending_turn_failed(
|
||||
run_id=run_id,
|
||||
text_session=text_session,
|
||||
error_message=str(e),
|
||||
)
|
||||
raise TextChatSessionExecutionError(
|
||||
"Failed to execute text chat assistant turn"
|
||||
) from e
|
||||
|
||||
completed_session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
completed_turns = list(completed_session_data.get("turns") or [])
|
||||
if not completed_turns or completed_turns[-1].get("status") != "pending":
|
||||
raise TextChatPendingTurnLostError(
|
||||
"Text chat session lost its pending turn before completion"
|
||||
)
|
||||
|
||||
completed_turns[-1]["status"] = "completed"
|
||||
completed_turns[-1]["assistant_message"] = (
|
||||
{
|
||||
"text": execution.assistant_text,
|
||||
"created_at": execution.assistant_created_at,
|
||||
}
|
||||
if execution.assistant_text
|
||||
else None
|
||||
)
|
||||
completed_turns[-1]["events"] = execution.events
|
||||
completed_turns[-1]["usage"] = execution.usage
|
||||
completed_turns[-1]["checkpoint_after_turn"] = execution.checkpoint
|
||||
completed_session_data["turns"] = completed_turns
|
||||
completed_session_data["status"] = "idle"
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=completed_session_data,
|
||||
checkpoint=execution.checkpoint,
|
||||
expected_revision=text_session.revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
raise TextChatSessionRevisionConflictError(
|
||||
expected_revision=e.expected_revision,
|
||||
actual_revision=e.actual_revision,
|
||||
) from e
|
||||
|
||||
existing_usage_info = text_session.workflow_run.usage_info or {}
|
||||
merged_usage_info = merge_text_chat_usage_info(existing_usage_info, execution.usage)
|
||||
text_chat_logs = {
|
||||
"realtime_feedback_events": build_text_chat_realtime_feedback_events(
|
||||
completed_session_data
|
||||
)
|
||||
}
|
||||
await db_client.update_workflow_run(
|
||||
run_id,
|
||||
initial_context=execution.initial_context,
|
||||
usage_info=merged_usage_info,
|
||||
gathered_context=execution.gathered_context,
|
||||
logs=text_chat_logs,
|
||||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
if workflow_run:
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is not None:
|
||||
await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
||||
|
||||
return await _reload_text_chat_session(run_id, text_session)
|
||||
|
||||
|
||||
def validate_text_chat_turn_cursor(
|
||||
session_data: dict[str, Any],
|
||||
cursor_turn_id: str | None,
|
||||
) -> None:
|
||||
if cursor_turn_id is None:
|
||||
return
|
||||
if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]):
|
||||
raise TextChatTurnNotFoundError("Turn not found in text chat session")
|
||||
|
||||
|
||||
def truncate_text_chat_future_turns(
|
||||
session_data: dict[str, Any],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
cursor_turn_id = session_data.get("cursor_turn_id")
|
||||
turns = list(session_data.get("turns") or [])
|
||||
discarded_future = list(session_data.get("discarded_future") or [])
|
||||
|
||||
if cursor_turn_id is None:
|
||||
return turns, discarded_future
|
||||
|
||||
for index, turn in enumerate(turns):
|
||||
if turn.get("id") == cursor_turn_id:
|
||||
active_turns = turns[: index + 1]
|
||||
future_turns = turns[index + 1 :]
|
||||
if future_turns:
|
||||
discarded_future.append(
|
||||
{
|
||||
"rewound_from_turn_id": cursor_turn_id,
|
||||
"discarded_at": datetime.now(UTC).isoformat(),
|
||||
"turns": future_turns,
|
||||
}
|
||||
)
|
||||
return active_turns, discarded_future
|
||||
|
||||
raise TextChatTurnNotFoundError("Turn not found in text chat session")
|
||||
|
||||
|
||||
def latest_completed_text_chat_turn_id(turns: list[dict[str, Any]]) -> str | None:
|
||||
for turn in reversed(turns):
|
||||
if turn.get("status") == "completed":
|
||||
return turn.get("id")
|
||||
return None
|
||||
|
||||
|
||||
def build_pending_text_chat_turn(*, user_text: str | None) -> dict[str, Any]:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
return {
|
||||
"id": f"turn_{uuid4().hex[:12]}",
|
||||
"status": "pending",
|
||||
"created_at": now,
|
||||
"user_message": (
|
||||
{
|
||||
"text": user_text,
|
||||
"created_at": now,
|
||||
}
|
||||
if user_text is not None
|
||||
else None
|
||||
),
|
||||
"assistant_message": None,
|
||||
"events": [],
|
||||
"usage": {},
|
||||
}
|
||||
|
||||
|
||||
async def _mark_pending_turn_failed(
|
||||
*,
|
||||
run_id: int,
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
failed_session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
failed_turns = list(failed_session_data.get("turns") or [])
|
||||
if not failed_turns or failed_turns[-1].get("status") != "pending":
|
||||
return
|
||||
|
||||
failed_turns[-1]["status"] = "failed"
|
||||
failed_turns[-1]["events"] = [
|
||||
*(failed_turns[-1].get("events") or []),
|
||||
{
|
||||
"type": "execution_error",
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"payload": {"message": error_message},
|
||||
},
|
||||
]
|
||||
failed_session_data["turns"] = failed_turns
|
||||
failed_session_data["status"] = "error"
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=failed_session_data,
|
||||
expected_revision=text_session.revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError:
|
||||
return
|
||||
|
||||
|
||||
async def _reload_text_chat_session(
|
||||
run_id: int,
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
organization_id = text_session.workflow_run.workflow.organization_id
|
||||
updated_text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if updated_text_session is None:
|
||||
raise TextChatSessionExecutionError("Text chat session not found after update")
|
||||
return updated_text_session
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TEXT_CHAT_SESSION_VERSION",
|
||||
"TextChatTurnNotFoundError",
|
||||
"append_text_chat_user_message",
|
||||
"build_pending_text_chat_turn",
|
||||
"TextChatPendingTurnLostError",
|
||||
"TextChatSessionExecutionError",
|
||||
"TextChatSessionRevisionConflictError",
|
||||
"default_text_chat_checkpoint",
|
||||
"default_text_chat_session_data",
|
||||
"execute_pending_text_chat_turn",
|
||||
"initialize_text_chat_session",
|
||||
"latest_completed_text_chat_turn_id",
|
||||
"normalize_text_chat_checkpoint",
|
||||
"normalize_text_chat_session_data",
|
||||
"rewind_text_chat_session_state",
|
||||
"truncate_text_chat_future_turns",
|
||||
"validate_text_chat_turn_cursor",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue