mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
chore: fix tracing for text chat mode
This commit is contained in:
parent
e23cce444f
commit
08a2435ba5
31 changed files with 1753 additions and 597 deletions
|
|
@ -57,6 +57,7 @@ class WorkflowRunUsageResponse(BaseModel):
|
|||
caller_number: Optional[str] = None
|
||||
called_number: Optional[str] = None
|
||||
call_type: Optional[str] = None
|
||||
mode: Optional[str] = None
|
||||
disposition: Optional[str] = None
|
||||
initial_context: Optional[Dict[str, Any]] = None
|
||||
gathered_context: Optional[Dict[str, Any]] = None
|
||||
|
|
|
|||
|
|
@ -1,27 +1,32 @@
|
|||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.workflow.text_chat_runner import (
|
||||
execute_text_chat_pending_turn,
|
||||
merge_text_chat_usage_info,
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
TextChatSessionExecutionError,
|
||||
TextChatSessionRevisionConflictError,
|
||||
TextChatTurnNotFoundError,
|
||||
append_text_chat_user_message,
|
||||
default_text_chat_checkpoint,
|
||||
default_text_chat_session_data,
|
||||
execute_pending_text_chat_turn,
|
||||
initialize_text_chat_session,
|
||||
normalize_text_chat_checkpoint,
|
||||
normalize_text_chat_session_data,
|
||||
rewind_text_chat_session_state,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/workflow", tags=["workflow-text-chat"])
|
||||
|
||||
TEXT_CHAT_SESSION_VERSION = 1
|
||||
TEXT_CHAT_CHECKPOINT_VERSION = 1
|
||||
|
||||
|
||||
class CreateTextChatSessionRequest(BaseModel):
|
||||
name: str | None = None
|
||||
|
|
@ -56,57 +61,6 @@ class WorkflowRunTextSessionResponse(BaseModel):
|
|||
updated_at: datetime | None = None
|
||||
|
||||
|
||||
def _default_session_data() -> Dict[str, Any]:
|
||||
return {
|
||||
"version": TEXT_CHAT_SESSION_VERSION,
|
||||
"status": "idle",
|
||||
"cursor_turn_id": None,
|
||||
"turns": [],
|
||||
"discarded_future": [],
|
||||
"simulator": {
|
||||
"enabled": False,
|
||||
"config": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _default_checkpoint() -> Dict[str, Any]:
|
||||
return {
|
||||
"version": TEXT_CHAT_CHECKPOINT_VERSION,
|
||||
"anchor_turn_id": None,
|
||||
"current_node_id": None,
|
||||
"messages": [],
|
||||
"gathered_context": {},
|
||||
"tool_state": {},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_session_data(session_data: Dict[str, Any] | None) -> Dict[str, Any]:
|
||||
normalized = {
|
||||
**_default_session_data(),
|
||||
**(session_data or {}),
|
||||
}
|
||||
normalized["turns"] = list(normalized.get("turns") or [])
|
||||
normalized["discarded_future"] = list(normalized.get("discarded_future") or [])
|
||||
simulator = normalized.get("simulator") or {}
|
||||
normalized["simulator"] = {
|
||||
"enabled": bool(simulator.get("enabled", False)),
|
||||
"config": dict(simulator.get("config") or {}),
|
||||
}
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_checkpoint(checkpoint: Dict[str, Any] | None) -> Dict[str, Any]:
|
||||
normalized = {
|
||||
**_default_checkpoint(),
|
||||
**(checkpoint or {}),
|
||||
}
|
||||
normalized["messages"] = list(normalized.get("messages") or [])
|
||||
normalized["gathered_context"] = dict(normalized.get("gathered_context") or {})
|
||||
normalized["tool_state"] = dict(normalized.get("tool_state") or {})
|
||||
return normalized
|
||||
|
||||
|
||||
def _get_state_value(state: Any) -> str:
|
||||
return state.value if hasattr(state, "value") else str(state)
|
||||
|
||||
|
|
@ -126,81 +80,14 @@ def _build_response(
|
|||
initial_context=workflow_run.initial_context,
|
||||
gathered_context=workflow_run.gathered_context,
|
||||
annotations=workflow_run.annotations,
|
||||
session_data=_normalize_session_data(text_session.session_data),
|
||||
checkpoint=_normalize_checkpoint(text_session.checkpoint),
|
||||
session_data=normalize_text_chat_session_data(text_session.session_data),
|
||||
checkpoint=normalize_text_chat_checkpoint(text_session.checkpoint),
|
||||
created_at=text_session.created_at,
|
||||
updated_at=text_session.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _validate_turn_cursor(
|
||||
session_data: Dict[str, Any], cursor_turn_id: str | None
|
||||
) -> None:
|
||||
if cursor_turn_id is None:
|
||||
return
|
||||
if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]):
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Turn not found in text chat session"
|
||||
)
|
||||
|
||||
|
||||
def _truncate_future_turns(
|
||||
session_data: Dict[str, Any],
|
||||
) -> tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
|
||||
cursor_turn_id = session_data.get("cursor_turn_id")
|
||||
turns = list(session_data.get("turns") or [])
|
||||
discarded_future = list(session_data.get("discarded_future") or [])
|
||||
|
||||
if cursor_turn_id is None:
|
||||
return turns, discarded_future
|
||||
|
||||
for index, turn in enumerate(turns):
|
||||
if turn.get("id") == cursor_turn_id:
|
||||
active_turns = turns[: index + 1]
|
||||
future_turns = turns[index + 1 :]
|
||||
if future_turns:
|
||||
discarded_future.append(
|
||||
{
|
||||
"rewound_from_turn_id": cursor_turn_id,
|
||||
"discarded_at": datetime.now(UTC).isoformat(),
|
||||
"turns": future_turns,
|
||||
}
|
||||
)
|
||||
return active_turns, discarded_future
|
||||
|
||||
raise HTTPException(status_code=404, detail="Turn not found in text chat session")
|
||||
|
||||
|
||||
def _latest_completed_turn_id(turns: list[Dict[str, Any]]) -> str | None:
|
||||
for turn in reversed(turns):
|
||||
if turn.get("status") == "completed":
|
||||
return turn.get("id")
|
||||
return None
|
||||
|
||||
|
||||
def _build_pending_turn(*, user_text: str | None) -> Dict[str, Any]:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
return {
|
||||
"id": f"turn_{uuid4().hex[:12]}",
|
||||
"status": "pending",
|
||||
"created_at": now,
|
||||
"user_message": (
|
||||
{
|
||||
"text": user_text,
|
||||
"created_at": now,
|
||||
}
|
||||
if user_text is not None
|
||||
else None
|
||||
),
|
||||
"assistant_message": None,
|
||||
"events": [],
|
||||
"usage": {},
|
||||
}
|
||||
|
||||
|
||||
def _revision_conflict_detail(
|
||||
e: WorkflowRunTextSessionRevisionConflictError,
|
||||
) -> dict[str, Any]:
|
||||
def _revision_conflict_detail(e: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"message": "Text chat session revision conflict",
|
||||
"expected_revision": e.expected_revision,
|
||||
|
|
@ -213,6 +100,7 @@ async def _load_text_session_or_404(
|
|||
run_id: int,
|
||||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
set_current_run_id(run_id)
|
||||
if user.selected_organization_id is None:
|
||||
raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
|
|
@ -229,96 +117,26 @@ async def _load_text_session_or_404(
|
|||
return text_session
|
||||
|
||||
|
||||
async def _execute_pending_turn_and_build_response(
|
||||
async def _execute_pending_turn_response(
|
||||
*,
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
try:
|
||||
execution = await execute_text_chat_pending_turn(
|
||||
workflow_run_id=run_id,
|
||||
updated_text_session = await execute_pending_text_chat_turn(
|
||||
workflow_id=workflow_id,
|
||||
session_data=_normalize_session_data(text_session.session_data),
|
||||
checkpoint=_normalize_checkpoint(text_session.checkpoint),
|
||||
run_id=run_id,
|
||||
text_session=text_session,
|
||||
)
|
||||
except Exception as e:
|
||||
failed_session_data = _normalize_session_data(text_session.session_data)
|
||||
failed_turns = list(failed_session_data.get("turns") or [])
|
||||
if failed_turns and failed_turns[-1].get("status") == "pending":
|
||||
failed_turns[-1]["status"] = "failed"
|
||||
failed_turns[-1]["events"] = [
|
||||
*(failed_turns[-1].get("events") or []),
|
||||
{
|
||||
"type": "execution_error",
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"payload": {"message": str(e)},
|
||||
},
|
||||
]
|
||||
failed_session_data["turns"] = failed_turns
|
||||
failed_session_data["status"] = "error"
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=failed_session_data,
|
||||
expected_revision=text_session.revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to execute text chat assistant turn"
|
||||
)
|
||||
|
||||
completed_session_data = _normalize_session_data(text_session.session_data)
|
||||
completed_turns = list(completed_session_data.get("turns") or [])
|
||||
if not completed_turns or completed_turns[-1].get("status") != "pending":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Text chat session lost its pending turn before completion",
|
||||
)
|
||||
|
||||
completed_turns[-1]["status"] = "completed"
|
||||
completed_turns[-1]["assistant_message"] = (
|
||||
{
|
||||
"text": execution.assistant_text,
|
||||
"created_at": execution.assistant_created_at,
|
||||
}
|
||||
if execution.assistant_text
|
||||
else None
|
||||
)
|
||||
completed_turns[-1]["events"] = execution.events
|
||||
completed_turns[-1]["usage"] = execution.usage
|
||||
completed_turns[-1]["checkpoint_after_turn"] = execution.checkpoint
|
||||
completed_session_data["turns"] = completed_turns
|
||||
completed_session_data["status"] = "idle"
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=completed_session_data,
|
||||
checkpoint=execution.checkpoint,
|
||||
expected_revision=text_session.revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
except TextChatSessionRevisionConflictError as e:
|
||||
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
|
||||
except TextChatPendingTurnLostError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except TextChatSessionExecutionError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
existing_usage_info = text_session.workflow_run.usage_info or {}
|
||||
merged_usage_info = merge_text_chat_usage_info(
|
||||
existing_usage_info,
|
||||
execution.usage,
|
||||
)
|
||||
await db_client.update_workflow_run(
|
||||
run_id,
|
||||
initial_context=execution.initial_context,
|
||||
usage_info=merged_usage_info,
|
||||
gathered_context=execution.gathered_context,
|
||||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
return _build_response(updated_text_session)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
@ -343,6 +161,8 @@ async def create_text_chat_session(
|
|||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
set_current_run_id(workflow_run.id)
|
||||
|
||||
annotations = {
|
||||
"tester": {
|
||||
"source": "workflow_editor",
|
||||
|
|
@ -358,32 +178,22 @@ async def create_text_chat_session(
|
|||
|
||||
text_session = await db_client.ensure_workflow_run_text_session(
|
||||
workflow_run.id,
|
||||
session_data=_default_session_data(),
|
||||
checkpoint=_default_checkpoint(),
|
||||
session_data=default_text_chat_session_data(),
|
||||
checkpoint=default_text_chat_checkpoint(),
|
||||
)
|
||||
session_data = _normalize_session_data(text_session.session_data)
|
||||
checkpoint = _normalize_checkpoint(text_session.checkpoint)
|
||||
|
||||
session_data["turns"] = [_build_pending_turn(user_text=None)]
|
||||
session_data["status"] = "pending_assistant_turn"
|
||||
checkpoint["anchor_turn_id"] = _latest_completed_turn_id(session_data["turns"])
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
workflow_run.id,
|
||||
session_data=session_data,
|
||||
checkpoint=checkpoint,
|
||||
expected_revision=text_session.revision,
|
||||
text_session = await initialize_text_chat_session(
|
||||
run_id=workflow_run.id,
|
||||
text_session=text_session,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
except TextChatSessionRevisionConflictError as e:
|
||||
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
|
||||
|
||||
text_session = await _load_text_session_or_404(workflow_id, workflow_run.id, user)
|
||||
return await _execute_pending_turn_and_build_response(
|
||||
return await _execute_pending_turn_response(
|
||||
workflow_id=workflow_id,
|
||||
run_id=workflow_run.id,
|
||||
text_session=text_session,
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -411,34 +221,20 @@ async def append_text_chat_message(
|
|||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
session_data = _normalize_session_data(text_session.session_data)
|
||||
checkpoint = _normalize_checkpoint(text_session.checkpoint)
|
||||
|
||||
active_turns, discarded_future = _truncate_future_turns(session_data)
|
||||
active_turns.append(_build_pending_turn(user_text=request.text))
|
||||
|
||||
session_data["turns"] = active_turns
|
||||
session_data["discarded_future"] = discarded_future
|
||||
session_data["cursor_turn_id"] = None
|
||||
session_data["status"] = "pending_assistant_turn"
|
||||
checkpoint["anchor_turn_id"] = _latest_completed_turn_id(active_turns)
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=session_data,
|
||||
checkpoint=checkpoint,
|
||||
text_session = await append_text_chat_user_message(
|
||||
run_id=run_id,
|
||||
text_session=text_session,
|
||||
user_text=request.text,
|
||||
expected_revision=request.expected_revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
except TextChatSessionRevisionConflictError as e:
|
||||
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
|
||||
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return await _execute_pending_turn_and_build_response(
|
||||
return await _execute_pending_turn_response(
|
||||
workflow_id=workflow_id,
|
||||
run_id=run_id,
|
||||
text_session=text_session,
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -453,27 +249,16 @@ async def rewind_text_chat_session(
|
|||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
session_data = _normalize_session_data(text_session.session_data)
|
||||
_validate_turn_cursor(session_data, request.cursor_turn_id)
|
||||
|
||||
session_data["cursor_turn_id"] = request.cursor_turn_id
|
||||
session_data["status"] = "rewound" if request.cursor_turn_id else "idle"
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=session_data,
|
||||
text_session = await rewind_text_chat_session_state(
|
||||
run_id=run_id,
|
||||
text_session=text_session,
|
||||
cursor_turn_id=request.cursor_turn_id,
|
||||
expected_revision=request.expected_revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"message": "Text chat session revision conflict",
|
||||
"expected_revision": e.expected_revision,
|
||||
"actual_revision": e.actual_revision,
|
||||
},
|
||||
)
|
||||
except TextChatTurnNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except TextChatSessionRevisionConflictError as e:
|
||||
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
|
||||
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue