dograh/api/routes/workflow_text_chat.py
2026-05-21 07:47:58 +05:30

479 lines
16 KiB
Python

from datetime import UTC, datetime
from typing import Any, Dict
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.workflow.text_chat_runner import (
execute_text_chat_pending_turn,
merge_text_chat_usage_info,
)
# Router for the workflow text-chat endpoints; every route below is registered
# under the "/workflow" prefix.
router = APIRouter(prefix="/workflow", tags=["workflow-text-chat"])
# Schema versions written into the "version" field of the default session-data
# and checkpoint payloads built by the helpers in this module.
TEXT_CHAT_SESSION_VERSION = 1
TEXT_CHAT_CHECKPOINT_VERSION = 1
class CreateTextChatSessionRequest(BaseModel):
    # Request body for creating a text chat session. Every field is optional.
    # (Documented with comments, not a docstring, so the generated schema
    # description stays unchanged.)

    # Display name for the run; the create endpoint generates one when omitted.
    name: str | None = None
    # Forwarded as the workflow run's initial context.
    initial_context: dict[str, Any] | None = None
    # Merged on top of the default "tester" annotations by the create endpoint.
    annotations: dict[str, Any] | None = None
class AppendTextChatMessageRequest(BaseModel):
    # Request body for appending a user message to a text chat session.

    # User message text; must contain at least one character.
    text: str = Field(min_length=1)
    # Revision the client believes is current; passed through to the session
    # update as its expected revision.
    expected_revision: int | None = None
class RewindTextChatSessionRequest(BaseModel):
    # Request body for moving a session's cursor to an earlier turn.

    # Id of the turn to rewind to; None clears the cursor.
    cursor_turn_id: str | None = None
    # Revision the client believes is current; passed through to the session
    # update as its expected revision.
    expected_revision: int | None = None
class WorkflowRunTextSessionResponse(BaseModel):
    # Response payload shared by every text-chat endpoint in this module. It
    # combines fields of the workflow run with the text session's own state.

    # --- Taken from the workflow run ---
    workflow_run_id: int
    workflow_id: int
    name: str
    mode: str
    state: str
    is_completed: bool
    # --- Taken from the text session ---
    revision: int
    # --- Taken from the workflow run ---
    initial_context: dict[str, Any] | None = None
    gathered_context: dict[str, Any] | None = None
    annotations: dict[str, Any] | None = None
    # --- Normalized text-session payloads ---
    session_data: dict[str, Any]
    checkpoint: dict[str, Any]
    # --- Text-session timestamps ---
    created_at: datetime
    updated_at: datetime | None = None
def _default_session_data() -> Dict[str, Any]:
    """Build the session-data payload for a text chat that has no turns yet.

    Returns:
        A newly constructed dict (including new nested containers) on every
        call, so the result can be mutated without affecting other callers.
    """
    simulator_defaults: Dict[str, Any] = {"enabled": False, "config": {}}
    return dict(
        version=TEXT_CHAT_SESSION_VERSION,
        status="idle",
        cursor_turn_id=None,
        turns=[],
        discarded_future=[],
        simulator=simulator_defaults,
    )
def _default_checkpoint() -> Dict[str, Any]:
    """Build the checkpoint payload for a text chat that has not run yet.

    Returns:
        A newly constructed dict (including new nested containers) on every
        call, so the result can be mutated without affecting other callers.
    """
    return dict(
        version=TEXT_CHAT_CHECKPOINT_VERSION,
        anchor_turn_id=None,
        current_node_id=None,
        messages=[],
        gathered_context={},
        tool_state={},
    )
def _normalize_session_data(session_data: Dict[str, Any] | None) -> Dict[str, Any]:
    """Overlay stored session data on the defaults and coerce container fields.

    Args:
        session_data: Stored payload; may be None or only partially populated.

    Returns:
        A new top-level dict in which "turns" and "discarded_future" are lists
        and "simulator" holds exactly an "enabled" bool and a "config" dict.
        Only the containers are copied; the individual turn dicts are the same
        objects as in the input.
    """
    merged = _default_session_data()
    if session_data:
        merged.update(session_data)
    # A stored None (or any falsy value) is replaced by an empty list.
    for list_key in ("turns", "discarded_future"):
        merged[list_key] = list(merged.get(list_key) or [])
    simulator_raw = merged.get("simulator") or {}
    merged["simulator"] = {
        "enabled": bool(simulator_raw.get("enabled", False)),
        "config": dict(simulator_raw.get("config") or {}),
    }
    return merged
def _normalize_checkpoint(checkpoint: Dict[str, Any] | None) -> Dict[str, Any]:
    """Overlay a stored checkpoint on the defaults and coerce container fields.

    Args:
        checkpoint: Stored payload; may be None or only partially populated.

    Returns:
        A new top-level dict in which "messages" is a list and both
        "gathered_context" and "tool_state" are dicts (shallow copies).
    """
    merged = _default_checkpoint()
    if checkpoint:
        merged.update(checkpoint)
    merged["messages"] = list(merged.get("messages") or [])
    # A stored None (or any falsy value) is replaced by an empty dict.
    for mapping_key in ("gathered_context", "tool_state"):
        merged[mapping_key] = dict(merged.get(mapping_key) or {})
    return merged
def _get_state_value(state: Any) -> str:
return state.value if hasattr(state, "value") else str(state)
def _build_response(
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionResponse:
    """Convert a loaded text session (and its workflow run) into the API response.

    Args:
        text_session: Session model whose ``workflow_run`` relationship is
            available.

    Returns:
        The response model, with session data and checkpoint normalized.
    """
    run = text_session.workflow_run
    session_payload = _normalize_session_data(text_session.session_data)
    checkpoint_payload = _normalize_checkpoint(text_session.checkpoint)
    return WorkflowRunTextSessionResponse(
        workflow_run_id=run.id,
        workflow_id=run.workflow_id,
        name=run.name,
        mode=run.mode,
        # The state may be an enum or a plain value; the helper handles both.
        state=_get_state_value(run.state),
        is_completed=run.is_completed,
        revision=text_session.revision,
        initial_context=run.initial_context,
        gathered_context=run.gathered_context,
        annotations=run.annotations,
        session_data=session_payload,
        checkpoint=checkpoint_payload,
        created_at=text_session.created_at,
        updated_at=text_session.updated_at,
    )
def _validate_turn_cursor(
session_data: Dict[str, Any], cursor_turn_id: str | None
) -> None:
if cursor_turn_id is None:
return
if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]):
raise HTTPException(
status_code=404, detail="Turn not found in text chat session"
)
def _truncate_future_turns(
session_data: Dict[str, Any],
) -> tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
cursor_turn_id = session_data.get("cursor_turn_id")
turns = list(session_data.get("turns") or [])
discarded_future = list(session_data.get("discarded_future") or [])
if cursor_turn_id is None:
return turns, discarded_future
for index, turn in enumerate(turns):
if turn.get("id") == cursor_turn_id:
active_turns = turns[: index + 1]
future_turns = turns[index + 1 :]
if future_turns:
discarded_future.append(
{
"rewound_from_turn_id": cursor_turn_id,
"discarded_at": datetime.now(UTC).isoformat(),
"turns": future_turns,
}
)
return active_turns, discarded_future
raise HTTPException(status_code=404, detail="Turn not found in text chat session")
def _latest_completed_turn_id(turns: list[Dict[str, Any]]) -> str | None:
for turn in reversed(turns):
if turn.get("status") == "completed":
return turn.get("id")
return None
def _build_pending_turn(*, user_text: str | None) -> Dict[str, Any]:
now = datetime.now(UTC).isoformat()
return {
"id": f"turn_{uuid4().hex[:12]}",
"status": "pending",
"created_at": now,
"user_message": (
{
"text": user_text,
"created_at": now,
}
if user_text is not None
else None
),
"assistant_message": None,
"events": [],
"usage": {},
}
def _revision_conflict_detail(
    e: WorkflowRunTextSessionRevisionConflictError,
) -> dict[str, Any]:
    """Build the ``detail`` payload for an HTTP 409 revision-conflict response.

    Args:
        e: The conflict error; its expected and actual revisions are echoed
            back to the client.
    """
    return dict(
        message="Text chat session revision conflict",
        expected_revision=e.expected_revision,
        actual_revision=e.actual_revision,
    )
async def _load_text_session_or_404(
    workflow_id: int,
    run_id: int,
    user: UserModel,
) -> WorkflowRunTextSessionModel:
    """Load a text chat session scoped to the user's selected organization.

    Args:
        workflow_id: Workflow id from the URL; must match the run's workflow.
        run_id: Workflow run id identifying the session.
        user: Authenticated user; ``selected_organization_id`` scopes the lookup.

    Returns:
        The text session model, with its workflow run present.

    Raises:
        HTTPException: 403 when the user has no selected organization; 404 when
            the session or its run is missing or belongs to another workflow;
            400 when the run's mode is not text chat.
    """
    organization_id = user.selected_organization_id
    if organization_id is None:
        raise HTTPException(status_code=403, detail="Organization context is required")
    text_session = await db_client.get_workflow_run_text_session(
        run_id, organization_id=organization_id
    )
    run = text_session.workflow_run if text_session else None
    # A missing session, a missing run and a workflow mismatch all look the
    # same to the client.
    if not run or run.workflow_id != workflow_id:
        raise HTTPException(status_code=404, detail="Text chat session not found")
    if run.mode != WorkflowRunMode.TEXTCHAT.value:
        raise HTTPException(
            status_code=400, detail="Workflow run is not a text chat session"
        )
    return text_session
async def _execute_pending_turn_and_build_response(
    *,
    workflow_id: int,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
    user: UserModel,
) -> WorkflowRunTextSessionResponse:
    """Execute the session's trailing pending turn and persist the outcome.

    On success the last turn is marked "completed", the session and checkpoint
    are saved under optimistic locking, the workflow run is updated with the
    merged usage, contexts and state, and the reloaded session is returned.
    On failure the last pending turn is marked "failed" with an
    ``execution_error`` event and the session status becomes "error".

    Args:
        workflow_id: Workflow the run belongs to.
        run_id: Workflow run id of the text chat session.
        text_session: Session as loaded immediately before execution; its
            revision is used as the expected revision for the writes here.
        user: Authenticated user, used to reload the session for the response.

    Returns:
        The response built from the session as reloaded after all writes.

    Raises:
        HTTPException: 500 when execution fails or the session has no trailing
            pending turn; 409 when the session revision changed before the
            completed turn could be saved.
    """
    try:
        execution = await execute_text_chat_pending_turn(
            workflow_run_id=run_id,
            workflow_id=workflow_id,
            session_data=_normalize_session_data(text_session.session_data),
            checkpoint=_normalize_checkpoint(text_session.checkpoint),
        )
    except Exception as e:
        failed_session_data = _normalize_session_data(text_session.session_data)
        failed_turns = failed_session_data["turns"]
        if failed_turns and failed_turns[-1].get("status") == "pending":
            # Normalization copies the list but not the turn dicts, so swap in
            # an updated copy instead of mutating the dict still referenced by
            # text_session.session_data.
            failed_turns[-1] = {
                **failed_turns[-1],
                "status": "failed",
                "events": [
                    *(failed_turns[-1].get("events") or []),
                    {
                        "type": "execution_error",
                        "created_at": datetime.now(UTC).isoformat(),
                        "payload": {"message": str(e)},
                    },
                ],
            }
        failed_session_data["status"] = "error"
        try:
            await db_client.update_workflow_run_text_session(
                run_id,
                session_data=failed_session_data,
                expected_revision=text_session.revision,
            )
        except WorkflowRunTextSessionRevisionConflictError:
            # Deliberate best effort: if the session moved on concurrently, the
            # failure marker is dropped and the original error is still reported.
            pass
        # Chain the cause so the underlying failure stays visible in tracebacks.
        raise HTTPException(
            status_code=500, detail="Failed to execute text chat assistant turn"
        ) from e
    completed_session_data = _normalize_session_data(text_session.session_data)
    completed_turns = completed_session_data["turns"]
    if not completed_turns or completed_turns[-1].get("status") != "pending":
        raise HTTPException(
            status_code=500,
            detail="Text chat session lost its pending turn before completion",
        )
    # An empty assistant text is stored as "no assistant message".
    assistant_message = (
        {
            "text": execution.assistant_text,
            "created_at": execution.assistant_created_at,
        }
        if execution.assistant_text
        else None
    )
    # Same copy-then-replace approach as the failure path above.
    completed_turns[-1] = {
        **completed_turns[-1],
        "status": "completed",
        "assistant_message": assistant_message,
        "events": execution.events,
        "usage": execution.usage,
        "checkpoint_after_turn": execution.checkpoint,
    }
    completed_session_data["status"] = "idle"
    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=completed_session_data,
            checkpoint=execution.checkpoint,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    existing_usage_info = text_session.workflow_run.usage_info or {}
    merged_usage_info = merge_text_chat_usage_info(
        existing_usage_info,
        execution.usage,
    )
    await db_client.update_workflow_run(
        run_id,
        initial_context=execution.initial_context,
        usage_info=merged_usage_info,
        gathered_context=execution.gathered_context,
        state=execution.state,
        is_completed=execution.is_completed,
    )
    # Reload so the response carries the new revision and run fields.
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    return _build_response(text_session)
@router.post(
    "/{workflow_id}/text-chat/sessions",
    response_model=WorkflowRunTextSessionResponse,
)
async def create_text_chat_session(
    workflow_id: int,
    request: CreateTextChatSessionRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    # Creates a text-chat workflow run plus its text session, queues an opening
    # turn with no user message, executes it, and returns the resulting session.
    # Responds 403 without an organization context, 404 when creating the run
    # raises ValueError, and 409 on a session revision conflict.
    # (Comments rather than a docstring: a docstring on a route function would
    # become the operation description in the generated OpenAPI schema.)

    # Loading the session below requires an organization context. Check it
    # before anything is written so a 403 does not leave an orphan run and
    # session behind.
    if user.selected_organization_id is None:
        raise HTTPException(status_code=403, detail="Organization context is required")
    session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
    try:
        workflow_run = await db_client.create_workflow_run(
            name=session_name,
            workflow_id=workflow_id,
            mode=WorkflowRunMode.TEXTCHAT.value,
            user_id=user.id,
            initial_context=request.initial_context,
            use_draft=True,
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    annotations = {
        "tester": {
            "source": "workflow_editor",
            "modality": "text",
        }
    }
    if request.annotations:
        # Shallow merge: a caller-supplied "tester" key replaces the default.
        annotations = {**annotations, **request.annotations}
    workflow_run = await db_client.update_workflow_run(
        workflow_run.id,
        annotations=annotations,
    )
    text_session = await db_client.ensure_workflow_run_text_session(
        workflow_run.id,
        session_data=_default_session_data(),
        checkpoint=_default_checkpoint(),
    )
    session_data = _normalize_session_data(text_session.session_data)
    checkpoint = _normalize_checkpoint(text_session.checkpoint)
    # The opening turn has no user message; the assistant speaks first.
    session_data["turns"] = [_build_pending_turn(user_text=None)]
    session_data["status"] = "pending_assistant_turn"
    checkpoint["anchor_turn_id"] = _latest_completed_turn_id(session_data["turns"])
    try:
        await db_client.update_workflow_run_text_session(
            workflow_run.id,
            session_data=session_data,
            checkpoint=checkpoint,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    # Reload to pick up the revision produced by the update above.
    text_session = await _load_text_session_or_404(workflow_id, workflow_run.id, user)
    return await _execute_pending_turn_and_build_response(
        workflow_id=workflow_id,
        run_id=workflow_run.id,
        text_session=text_session,
        user=user,
    )
@router.get(
    "/{workflow_id}/text-chat/sessions/{run_id}",
    response_model=WorkflowRunTextSessionResponse,
)
async def get_text_chat_session(
    workflow_id: int,
    run_id: int,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    # Returns the current state of one text chat session. Lookup errors
    # (403/404/400) are raised by the loader.
    # (Comments rather than a docstring, so the generated OpenAPI operation
    # description stays unchanged.)
    return _build_response(await _load_text_session_or_404(workflow_id, run_id, user))
@router.post(
    "/{workflow_id}/text-chat/sessions/{run_id}/messages",
    response_model=WorkflowRunTextSessionResponse,
)
async def append_text_chat_message(
    workflow_id: int,
    run_id: int,
    request: AppendTextChatMessageRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    # Appends a user message as a new pending turn, executes it, and returns
    # the updated session. If the session was rewound, the turns after the
    # cursor are moved to "discarded_future" first and the cursor is cleared.
    # Responds 409 on a session revision conflict; lookup errors come from the
    # loader, and a stale cursor yields 404.
    # (Comments rather than a docstring, so the generated OpenAPI operation
    # description stays unchanged.)
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    session_data = _normalize_session_data(text_session.session_data)
    checkpoint = _normalize_checkpoint(text_session.checkpoint)
    active_turns, discarded_future = _truncate_future_turns(session_data)
    active_turns.append(_build_pending_turn(user_text=request.text))
    session_data["turns"] = active_turns
    session_data["discarded_future"] = discarded_future
    session_data["cursor_turn_id"] = None
    session_data["status"] = "pending_assistant_turn"
    # Anchor the checkpoint to the last completed turn still in the active list.
    checkpoint["anchor_turn_id"] = _latest_completed_turn_id(active_turns)
    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=session_data,
            checkpoint=checkpoint,
            expected_revision=request.expected_revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        # Chain the cause so the conflict stays visible in tracebacks.
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    # Reload to pick up the revision produced by the update above.
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    return await _execute_pending_turn_and_build_response(
        workflow_id=workflow_id,
        run_id=run_id,
        text_session=text_session,
        user=user,
    )
@router.post(
    "/{workflow_id}/text-chat/sessions/{run_id}/rewind",
    response_model=WorkflowRunTextSessionResponse,
)
async def rewind_text_chat_session(
    workflow_id: int,
    run_id: int,
    request: RewindTextChatSessionRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    # Moves the session cursor to an existing turn, or clears it when
    # cursor_turn_id is None. Only the cursor and status change here; no turns
    # are removed by this endpoint.
    # Responds 404 for an unknown turn id and 409 on a session revision
    # conflict; lookup errors come from the loader.
    # (Comments rather than a docstring, so the generated OpenAPI operation
    # description stays unchanged.)
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    session_data = _normalize_session_data(text_session.session_data)
    _validate_turn_cursor(session_data, request.cursor_turn_id)
    session_data["cursor_turn_id"] = request.cursor_turn_id
    session_data["status"] = "rewound" if request.cursor_turn_id else "idle"
    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=session_data,
            expected_revision=request.expected_revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        # Use the shared helper so every endpoint returns the same 409 payload.
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    return _build_response(text_session)