mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* feat: add backend foundations * feat: add text chat UI * chore: simplify the reload behaviour * fix: fix upgrade banner to be triggered after package upload * feat: simplify TesterPanel design * chore: fix formatting and generate client * chore: fix tracing for text chat mode * fix: fix revert and edit CTA * refactor: refactor TesterPanel into smaller components * feat: enable runtime transition of nodes * fix: fix review comments
411 lines
14 KiB
Python
411 lines
14 KiB
Python
"""Service helpers for text-chat session lifecycle orchestration."""
|
|
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
from loguru import logger
|
|
|
|
from api.db import db_client
|
|
from api.db.models import WorkflowRunTextSessionModel
|
|
from api.db.workflow_run_text_session_client import (
|
|
WorkflowRunTextSessionRevisionConflictError,
|
|
)
|
|
from api.services.pricing.workflow_run_cost import (
|
|
apply_usage_delta_to_organization,
|
|
build_workflow_run_cost_info,
|
|
)
|
|
from api.services.workflow.text_chat_logs import (
|
|
build_text_chat_realtime_feedback_events,
|
|
)
|
|
from api.services.workflow.text_chat_runner import (
|
|
default_text_chat_checkpoint,
|
|
execute_text_chat_pending_turn,
|
|
merge_text_chat_usage_info,
|
|
normalize_text_chat_checkpoint,
|
|
)
|
|
|
|
# Schema version stamped into new session payloads by
# default_text_chat_session_data(). normalize_text_chat_session_data() only
# falls back to it when the stored payload has no "version" key, so existing
# sessions keep whatever version they were written with.
TEXT_CHAT_SESSION_VERSION = 1
|
|
|
|
|
|
class TextChatSessionRevisionConflictError(Exception):
    """Signal that a text chat session write lost an optimistic-concurrency race.

    Raised by this service in place of the storage layer's revision-conflict
    error, carrying the same two revision numbers.

    Attributes:
        expected_revision: Revision the caller based its update on.
        actual_revision: Revision reported by the storage layer.
    """

    def __init__(self, expected_revision: int, actual_revision: int):
        message = (
            "Text chat session revision conflict: "
            f"expected {expected_revision}, found {actual_revision}"
        )
        super().__init__(message)
        self.expected_revision = expected_revision
        self.actual_revision = actual_revision
|
|
|
|
|
|
class TextChatSessionExecutionError(Exception):
    """Signal that an assistant turn could not be run or its result reloaded.

    Raised when executing the pending assistant turn fails. It is also raised
    when the session or its owning organization cannot be found right after a
    write.
    """
|
|
|
|
|
|
class TextChatPendingTurnLostError(Exception):
    """Signal that the session no longer ends in a pending turn.

    Raised when an executed assistant turn's results are about to be written
    but the session's last turn is missing or not ``pending``.
    """
|
|
|
|
|
|
class TextChatTurnNotFoundError(Exception):
    """Signal that a turn id used as a rewind cursor is not in the session.

    Raised while validating a requested cursor and while truncating turns after
    the stored cursor.
    """
|
|
|
|
|
|
def default_text_chat_session_data() -> dict[str, Any]:
    """Build a fresh, empty text chat session payload.

    Every call returns new container objects, so callers may mutate the result
    without affecting other sessions.
    """
    simulator_defaults: dict[str, Any] = {"enabled": False, "config": {}}
    return dict(
        version=TEXT_CHAT_SESSION_VERSION,
        status="idle",
        cursor_turn_id=None,
        turns=[],
        discarded_future=[],
        simulator=simulator_defaults,
    )
|
|
|
|
|
|
def normalize_text_chat_session_data(
    session_data: dict[str, Any] | None,
) -> dict[str, Any]:
    """Overlay ``session_data`` on the defaults and coerce container fields.

    Keys absent from ``session_data`` (or all keys, when it is ``None``) come
    from ``default_text_chat_session_data()``. ``turns`` and
    ``discarded_future`` are copied into new lists, with falsy values becoming
    ``[]``. ``simulator`` is rebuilt to hold exactly ``enabled`` (a bool) and
    ``config`` (a dict copy).

    The copies are shallow: the individual turn dicts are the same objects as
    in the input.
    """
    merged = default_text_chat_session_data()
    merged.update(session_data or {})

    for list_key in ("turns", "discarded_future"):
        merged[list_key] = list(merged.get(list_key) or [])

    raw_simulator = merged.get("simulator") or {}
    merged["simulator"] = {
        "enabled": bool(raw_simulator.get("enabled", False)),
        "config": dict(raw_simulator.get("config") or {}),
    }
    return merged
|
|
|
|
|
|
async def initialize_text_chat_session(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionModel:
    """Reset the session to a single pending turn and persist it.

    All existing turns are replaced by one pending turn that has no user
    message. The session status becomes ``pending_assistant_turn``. The write
    is guarded by ``text_session.revision``.

    Args:
        run_id: Workflow run that owns the text session.
        text_session: Current session model; its data, checkpoint and revision
            are read.

    Returns:
        The session reloaded from storage after the update.

    Raises:
        TextChatSessionRevisionConflictError: if the stored revision no longer
            matches ``text_session.revision``.
        TextChatSessionExecutionError: if the session cannot be reloaded.
    """
    session_data = normalize_text_chat_session_data(text_session.session_data)
    checkpoint = normalize_text_chat_checkpoint(text_session.checkpoint)

    session_data["turns"] = [build_pending_text_chat_turn(user_text=None)]
    # Every previous turn id is gone, so any existing rewind cursor is now
    # dangling. Left in place, it would make the next
    # append_text_chat_user_message() fail inside
    # truncate_text_chat_future_turns() with TextChatTurnNotFoundError.
    session_data["cursor_turn_id"] = None
    session_data["status"] = "pending_assistant_turn"
    checkpoint["anchor_turn_id"] = latest_completed_text_chat_turn_id(
        session_data["turns"]
    )

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=session_data,
            checkpoint=checkpoint,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        raise TextChatSessionRevisionConflictError(
            expected_revision=e.expected_revision,
            actual_revision=e.actual_revision,
        ) from e

    return await _reload_text_chat_session(run_id)
|
|
|
|
|
|
async def append_text_chat_user_message(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
    user_text: str,
    expected_revision: int | None,
) -> WorkflowRunTextSessionModel:
    """Queue ``user_text`` as a new pending turn and persist the session.

    If the session has a rewind cursor, turns after it are first moved into
    ``discarded_future``; the cursor is then cleared. The session status
    becomes ``pending_assistant_turn``. The checkpoint's ``anchor_turn_id``
    is set to the latest completed turn that remains.

    Returns:
        The session reloaded from storage after the update.

    Raises:
        TextChatTurnNotFoundError: if the stored cursor names a missing turn.
        TextChatSessionRevisionConflictError: if the stored revision no longer
            matches ``expected_revision``.
        TextChatSessionExecutionError: if the session cannot be reloaded.
    """
    current = normalize_text_chat_session_data(text_session.session_data)
    next_checkpoint = normalize_text_chat_checkpoint(text_session.checkpoint)

    kept_turns, archived_future = truncate_text_chat_future_turns(current)
    kept_turns.append(build_pending_text_chat_turn(user_text=user_text))

    # All four keys already exist after normalization, so update() keeps the
    # payload's key order unchanged.
    current.update(
        turns=kept_turns,
        discarded_future=archived_future,
        cursor_turn_id=None,
        status="pending_assistant_turn",
    )
    next_checkpoint["anchor_turn_id"] = latest_completed_text_chat_turn_id(
        kept_turns
    )

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=current,
            checkpoint=next_checkpoint,
            expected_revision=expected_revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as conflict:
        raise TextChatSessionRevisionConflictError(
            expected_revision=conflict.expected_revision,
            actual_revision=conflict.actual_revision,
        ) from conflict

    return await _reload_text_chat_session(run_id)
|
|
|
|
|
|
async def rewind_text_chat_session_state(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
    cursor_turn_id: str | None,
    expected_revision: int | None,
) -> WorkflowRunTextSessionModel:
    """Move the session's rewind cursor and persist it.

    Only ``cursor_turn_id`` and ``status`` change; turns after the cursor stay
    in place until the next user message truncates them. A truthy cursor sets
    the status to ``rewound``; otherwise it becomes ``idle``. Afterwards the
    workflow run's ``realtime_feedback_events`` log is rebuilt from the
    updated session.

    Returns:
        The session reloaded from storage after the update.

    Raises:
        TextChatTurnNotFoundError: if ``cursor_turn_id`` is not a known turn.
        TextChatSessionRevisionConflictError: if the stored revision no longer
            matches ``expected_revision``.
        TextChatSessionExecutionError: if the session cannot be reloaded.
    """
    rewound = normalize_text_chat_session_data(text_session.session_data)
    validate_text_chat_turn_cursor(rewound, cursor_turn_id)

    rewound["cursor_turn_id"] = cursor_turn_id
    if cursor_turn_id:
        rewound["status"] = "rewound"
    else:
        rewound["status"] = "idle"

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=rewound,
            expected_revision=expected_revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as conflict:
        raise TextChatSessionRevisionConflictError(
            expected_revision=conflict.expected_revision,
            actual_revision=conflict.actual_revision,
        ) from conflict

    feedback_events = build_text_chat_realtime_feedback_events(rewound)
    await db_client.update_workflow_run(
        run_id,
        logs={"realtime_feedback_events": feedback_events},
    )

    return await _reload_text_chat_session(run_id)
|
|
|
|
|
|
async def execute_pending_text_chat_turn(
    *,
    workflow_id: int,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionModel:
    """Execute the current pending assistant turn and persist its side effects.

    On success this:

    1. marks the session's trailing pending turn ``completed``, recording the
       assistant message, events, usage and post-turn checkpoint, and sets the
       session status to ``idle``;
    2. saves the session and the new checkpoint, guarded by
       ``text_session.revision``;
    3. updates the workflow run's initial/gathered context, merged usage info,
       realtime feedback events, state and completion flag;
    4. applies this turn's usage to the organization and refreshes the run's
       cost info.

    If execution itself fails, the pending turn is marked ``failed`` on a
    best-effort basis before raising.

    Returns:
        The session reloaded from storage after the update.

    Raises:
        TextChatSessionExecutionError: if the assistant turn fails to execute,
            or if the session cannot be reloaded afterwards.
        TextChatPendingTurnLostError: if the session does not end with a pending
            turn when results are about to be recorded.
        TextChatSessionRevisionConflictError: if the stored revision no longer
            matches ``text_session.revision``.
    """
    session_data = normalize_text_chat_session_data(text_session.session_data)
    checkpoint = normalize_text_chat_checkpoint(text_session.checkpoint)

    try:
        execution = await execute_text_chat_pending_turn(
            workflow_run_id=run_id,
            workflow_id=workflow_id,
            session_data=session_data,
            checkpoint=checkpoint,
        )
    except Exception as e:
        try:
            await _mark_pending_turn_failed(
                run_id=run_id,
                text_session=text_session,
                error_message=str(e),
            )
        except Exception:
            # Recording the failure is best-effort. A storage error here must
            # not replace the execution error that callers handle.
            logger.exception(
                f"Failed to mark text chat turn as failed for run {run_id}"
            )
        raise TextChatSessionExecutionError(
            "Failed to execute text chat assistant turn"
        ) from e

    completed_session_data = _build_completed_text_chat_session_data(
        text_session, execution
    )

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=completed_session_data,
            checkpoint=execution.checkpoint,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        raise TextChatSessionRevisionConflictError(
            expected_revision=e.expected_revision,
            actual_revision=e.actual_revision,
        ) from e

    existing_usage_info = text_session.workflow_run.usage_info or {}
    merged_usage_info = merge_text_chat_usage_info(existing_usage_info, execution.usage)
    text_chat_logs = {
        "realtime_feedback_events": build_text_chat_realtime_feedback_events(
            completed_session_data
        )
    }
    await db_client.update_workflow_run(
        run_id,
        initial_context=execution.initial_context,
        usage_info=merged_usage_info,
        gathered_context=execution.gathered_context,
        logs=text_chat_logs,
        state=execution.state,
        is_completed=execution.is_completed,
    )

    await _record_text_chat_run_costs(run_id, execution.usage)

    return await _reload_text_chat_session(run_id)


def _build_completed_text_chat_session_data(
    text_session: WorkflowRunTextSessionModel,
    execution: Any,
) -> dict[str, Any]:
    """Return a copy of the session data whose trailing pending turn is completed."""
    # Start again from the model's stored session data rather than from the
    # dict that was passed to the runner.
    completed_session_data = normalize_text_chat_session_data(text_session.session_data)
    completed_turns = completed_session_data["turns"]
    if not completed_turns or completed_turns[-1].get("status") != "pending":
        raise TextChatPendingTurnLostError(
            "Text chat session lost its pending turn before completion"
        )

    # Replace the turn rather than mutating it. Normalization copies the turns
    # list shallowly, so this dict is also part of text_session.session_data.
    completed_turns[-1] = {
        **completed_turns[-1],
        "status": "completed",
        "assistant_message": (
            {
                "text": execution.assistant_text,
                "created_at": execution.assistant_created_at,
            }
            if execution.assistant_text
            else None
        ),
        "events": execution.events,
        "usage": execution.usage,
        "checkpoint_after_turn": execution.checkpoint,
    }
    completed_session_data["status"] = "idle"
    return completed_session_data


async def _record_text_chat_run_costs(run_id: int, turn_usage: Any) -> None:
    """Apply one turn's usage to the organization and refresh the run's cost info."""
    workflow_run = await db_client.get_workflow_run_by_id(run_id)
    if not workflow_run:
        return

    try:
        # Apply the per-turn delta so org usage tracks cumulative run cost
        # without replaying the full session totals on every turn.
        await apply_usage_delta_to_organization(workflow_run, turn_usage)
    except Exception as e:
        logger.exception(
            f"Failed to update organization usage for text chat run {run_id}: {e}"
        )

    cost_info = await build_workflow_run_cost_info(workflow_run)
    if cost_info is not None:
        await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
|
|
|
|
|
def validate_text_chat_turn_cursor(
|
|
session_data: dict[str, Any],
|
|
cursor_turn_id: str | None,
|
|
) -> None:
|
|
if cursor_turn_id is None:
|
|
return
|
|
if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]):
|
|
raise TextChatTurnNotFoundError("Turn not found in text chat session")
|
|
|
|
|
|
def truncate_text_chat_future_turns(
|
|
session_data: dict[str, Any],
|
|
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
|
cursor_turn_id = session_data.get("cursor_turn_id")
|
|
turns = list(session_data.get("turns") or [])
|
|
discarded_future = list(session_data.get("discarded_future") or [])
|
|
|
|
if cursor_turn_id is None:
|
|
return turns, discarded_future
|
|
|
|
for index, turn in enumerate(turns):
|
|
if turn.get("id") == cursor_turn_id:
|
|
active_turns = turns[: index + 1]
|
|
future_turns = turns[index + 1 :]
|
|
if future_turns:
|
|
discarded_future.append(
|
|
{
|
|
"rewound_from_turn_id": cursor_turn_id,
|
|
"discarded_at": datetime.now(UTC).isoformat(),
|
|
"turns": future_turns,
|
|
}
|
|
)
|
|
return active_turns, discarded_future
|
|
|
|
raise TextChatTurnNotFoundError("Turn not found in text chat session")
|
|
|
|
|
|
def latest_completed_text_chat_turn_id(turns: list[dict[str, Any]]) -> str | None:
|
|
for turn in reversed(turns):
|
|
if turn.get("status") == "completed":
|
|
return turn.get("id")
|
|
return None
|
|
|
|
|
|
def build_pending_text_chat_turn(*, user_text: str | None) -> dict[str, Any]:
|
|
now = datetime.now(UTC).isoformat()
|
|
return {
|
|
"id": f"turn_{uuid4().hex[:12]}",
|
|
"status": "pending",
|
|
"created_at": now,
|
|
"user_message": (
|
|
{
|
|
"text": user_text,
|
|
"created_at": now,
|
|
}
|
|
if user_text is not None
|
|
else None
|
|
),
|
|
"assistant_message": None,
|
|
"events": [],
|
|
"usage": {},
|
|
}
|
|
|
|
|
|
async def _mark_pending_turn_failed(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
    error_message: str,
) -> None:
    """Best-effort: record that the session's trailing pending turn failed.

    This does nothing unless the last turn is ``pending``. Otherwise it:

    - appends an ``execution_error`` event carrying ``error_message`` to that turn;
    - sets the turn's status to ``failed`` and the session's status to ``error``;
    - persists the result, guarded by ``text_session.revision``.

    A revision conflict means the session changed concurrently. In that case
    the failure marker is skipped and a warning is logged instead of raising.
    """
    failed_session_data = normalize_text_chat_session_data(text_session.session_data)
    failed_turns = failed_session_data["turns"]
    if not failed_turns or failed_turns[-1].get("status") != "pending":
        return

    pending_turn = failed_turns[-1]
    # Replace rather than mutate. Normalization copies the turns list
    # shallowly, so this dict is also part of text_session.session_data.
    failed_turns[-1] = {
        **pending_turn,
        "status": "failed",
        "events": [
            *(pending_turn.get("events") or []),
            {
                "type": "execution_error",
                "created_at": datetime.now(UTC).isoformat(),
                "payload": {"message": error_message},
            },
        ],
    }
    failed_session_data["status"] = "error"

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=failed_session_data,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        logger.warning(
            f"Skipped marking text chat turn as failed for run {run_id}: "
            f"session revision changed (expected {e.expected_revision}, "
            f"found {e.actual_revision})"
        )
|
|
|
|
|
|
async def _reload_text_chat_session(run_id: int) -> WorkflowRunTextSessionModel:
    """Fetch the persisted text chat session for ``run_id`` after a write.

    The owning organization id is resolved first, because the session lookup
    takes it as a parameter.

    Raises:
        TextChatSessionExecutionError: if either the organization or the
            session cannot be found.
    """
    org_id = await db_client.get_organization_id_by_workflow_run_id(run_id)
    if org_id is None:
        raise TextChatSessionExecutionError(
            "Workflow run organization not found after update"
        )

    reloaded = await db_client.get_workflow_run_text_session(
        run_id, organization_id=org_id
    )
    if reloaded is not None:
        return reloaded
    raise TextChatSessionExecutionError("Text chat session not found after update")
|
|
|
|
|
|
__all__ = [
    # Constants
    "TEXT_CHAT_SESSION_VERSION",
    # Exceptions
    "TextChatPendingTurnLostError",
    "TextChatSessionExecutionError",
    "TextChatSessionRevisionConflictError",
    "TextChatTurnNotFoundError",
    # Session helpers, plus the checkpoint helpers re-exported from
    # api.services.workflow.text_chat_runner
    "append_text_chat_user_message",
    "build_pending_text_chat_turn",
    "default_text_chat_checkpoint",
    "default_text_chat_session_data",
    "execute_pending_text_chat_turn",
    "initialize_text_chat_session",
    "latest_completed_text_chat_turn_id",
    "normalize_text_chat_checkpoint",
    "normalize_text_chat_session_data",
    "rewind_text_chat_session_state",
    "truncate_text_chat_future_turns",
    "validate_text_chat_turn_cursor",
]
|