dograh/api/routes/workflow_text_chat.py

480 lines
16 KiB
Python
Raw Normal View History

2026-05-18 13:26:29 +05:30
from datetime import UTC, datetime
from typing import Any, Dict
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
2026-05-19 08:24:39 +05:30
from api.services.workflow.text_chat_runner import (
execute_text_chat_pending_turn,
merge_text_chat_usage_info,
)
2026-05-18 13:26:29 +05:30
router = APIRouter(prefix="/workflow", tags=["workflow-text-chat"])
TEXT_CHAT_SESSION_VERSION = 1
TEXT_CHAT_CHECKPOINT_VERSION = 1
class CreateTextChatSessionRequest(BaseModel):
name: str | None = None
initial_context: Dict[str, Any] | None = None
annotations: Dict[str, Any] | None = None
class AppendTextChatMessageRequest(BaseModel):
text: str = Field(min_length=1)
expected_revision: int | None = None
class RewindTextChatSessionRequest(BaseModel):
cursor_turn_id: str | None = None
expected_revision: int | None = None
class WorkflowRunTextSessionResponse(BaseModel):
workflow_run_id: int
workflow_id: int
name: str
mode: str
state: str
is_completed: bool
revision: int
initial_context: Dict[str, Any] | None = None
gathered_context: Dict[str, Any] | None = None
annotations: Dict[str, Any] | None = None
session_data: Dict[str, Any]
checkpoint: Dict[str, Any]
created_at: datetime
updated_at: datetime | None = None
def _default_session_data() -> Dict[str, Any]:
return {
"version": TEXT_CHAT_SESSION_VERSION,
"status": "idle",
"cursor_turn_id": None,
"turns": [],
"discarded_future": [],
"simulator": {
"enabled": False,
"config": {},
},
}
def _default_checkpoint() -> Dict[str, Any]:
return {
"version": TEXT_CHAT_CHECKPOINT_VERSION,
"anchor_turn_id": None,
"current_node_id": None,
"messages": [],
"gathered_context": {},
"tool_state": {},
}
def _normalize_session_data(session_data: Dict[str, Any] | None) -> Dict[str, Any]:
normalized = {
**_default_session_data(),
**(session_data or {}),
}
normalized["turns"] = list(normalized.get("turns") or [])
normalized["discarded_future"] = list(normalized.get("discarded_future") or [])
simulator = normalized.get("simulator") or {}
normalized["simulator"] = {
"enabled": bool(simulator.get("enabled", False)),
"config": dict(simulator.get("config") or {}),
}
return normalized
def _normalize_checkpoint(checkpoint: Dict[str, Any] | None) -> Dict[str, Any]:
normalized = {
**_default_checkpoint(),
**(checkpoint or {}),
}
normalized["messages"] = list(normalized.get("messages") or [])
normalized["gathered_context"] = dict(normalized.get("gathered_context") or {})
normalized["tool_state"] = dict(normalized.get("tool_state") or {})
return normalized
def _get_state_value(state: Any) -> str:
return state.value if hasattr(state, "value") else str(state)
def _build_response(
text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionResponse:
workflow_run = text_session.workflow_run
return WorkflowRunTextSessionResponse(
workflow_run_id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
name=workflow_run.name,
mode=workflow_run.mode,
state=_get_state_value(workflow_run.state),
is_completed=workflow_run.is_completed,
revision=text_session.revision,
initial_context=workflow_run.initial_context,
gathered_context=workflow_run.gathered_context,
annotations=workflow_run.annotations,
session_data=_normalize_session_data(text_session.session_data),
checkpoint=_normalize_checkpoint(text_session.checkpoint),
created_at=text_session.created_at,
updated_at=text_session.updated_at,
)
2026-05-18 14:38:51 +05:30
def _validate_turn_cursor(
session_data: Dict[str, Any], cursor_turn_id: str | None
) -> None:
2026-05-18 13:26:29 +05:30
if cursor_turn_id is None:
return
if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]):
2026-05-18 14:38:51 +05:30
raise HTTPException(
status_code=404, detail="Turn not found in text chat session"
)
2026-05-18 13:26:29 +05:30
def _truncate_future_turns(
session_data: Dict[str, Any],
) -> tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
cursor_turn_id = session_data.get("cursor_turn_id")
turns = list(session_data.get("turns") or [])
discarded_future = list(session_data.get("discarded_future") or [])
if cursor_turn_id is None:
return turns, discarded_future
for index, turn in enumerate(turns):
if turn.get("id") == cursor_turn_id:
active_turns = turns[: index + 1]
future_turns = turns[index + 1 :]
if future_turns:
discarded_future.append(
{
"rewound_from_turn_id": cursor_turn_id,
"discarded_at": datetime.now(UTC).isoformat(),
"turns": future_turns,
}
)
return active_turns, discarded_future
raise HTTPException(status_code=404, detail="Turn not found in text chat session")
def _latest_completed_turn_id(turns: list[Dict[str, Any]]) -> str | None:
for turn in reversed(turns):
2026-05-19 08:24:39 +05:30
if turn.get("status") == "completed":
2026-05-18 13:26:29 +05:30
return turn.get("id")
return None
2026-05-19 08:24:39 +05:30
def _build_pending_turn(*, user_text: str | None) -> Dict[str, Any]:
now = datetime.now(UTC).isoformat()
return {
"id": f"turn_{uuid4().hex[:12]}",
"status": "pending",
"created_at": now,
"user_message": (
{
"text": user_text,
"created_at": now,
}
if user_text is not None
else None
),
"assistant_message": None,
"events": [],
"usage": {},
}
def _revision_conflict_detail(
e: WorkflowRunTextSessionRevisionConflictError,
) -> dict[str, Any]:
return {
"message": "Text chat session revision conflict",
"expected_revision": e.expected_revision,
"actual_revision": e.actual_revision,
}
2026-05-18 13:26:29 +05:30
async def _load_text_session_or_404(
workflow_id: int,
run_id: int,
user: UserModel,
) -> WorkflowRunTextSessionModel:
2026-05-19 08:24:39 +05:30
if user.selected_organization_id is None:
raise HTTPException(status_code=403, detail="Organization context is required")
2026-05-18 13:26:29 +05:30
text_session = await db_client.get_workflow_run_text_session(
run_id, organization_id=user.selected_organization_id
)
if not text_session or not text_session.workflow_run:
raise HTTPException(status_code=404, detail="Text chat session not found")
if text_session.workflow_run.workflow_id != workflow_id:
raise HTTPException(status_code=404, detail="Text chat session not found")
if text_session.workflow_run.mode != WorkflowRunMode.TEXTCHAT.value:
2026-05-18 14:38:51 +05:30
raise HTTPException(
status_code=400, detail="Workflow run is not a text chat session"
)
2026-05-18 13:26:29 +05:30
return text_session
2026-05-19 08:24:39 +05:30
async def _execute_pending_turn_and_build_response(
*,
workflow_id: int,
run_id: int,
text_session: WorkflowRunTextSessionModel,
user: UserModel,
) -> WorkflowRunTextSessionResponse:
try:
execution = await execute_text_chat_pending_turn(
workflow_run_id=run_id,
workflow_id=workflow_id,
session_data=_normalize_session_data(text_session.session_data),
checkpoint=_normalize_checkpoint(text_session.checkpoint),
)
except Exception as e:
failed_session_data = _normalize_session_data(text_session.session_data)
failed_turns = list(failed_session_data.get("turns") or [])
if failed_turns and failed_turns[-1].get("status") == "pending":
failed_turns[-1]["status"] = "failed"
failed_turns[-1]["events"] = [
*(failed_turns[-1].get("events") or []),
{
"type": "execution_error",
"created_at": datetime.now(UTC).isoformat(),
"payload": {"message": str(e)},
},
]
failed_session_data["turns"] = failed_turns
failed_session_data["status"] = "error"
try:
await db_client.update_workflow_run_text_session(
run_id,
session_data=failed_session_data,
expected_revision=text_session.revision,
)
except WorkflowRunTextSessionRevisionConflictError:
pass
raise HTTPException(
status_code=500, detail="Failed to execute text chat assistant turn"
)
completed_session_data = _normalize_session_data(text_session.session_data)
completed_turns = list(completed_session_data.get("turns") or [])
if not completed_turns or completed_turns[-1].get("status") != "pending":
raise HTTPException(
status_code=500,
detail="Text chat session lost its pending turn before completion",
)
completed_turns[-1]["status"] = "completed"
completed_turns[-1]["assistant_message"] = (
{
"text": execution.assistant_text,
"created_at": execution.assistant_created_at,
}
if execution.assistant_text
else None
)
completed_turns[-1]["events"] = execution.events
completed_turns[-1]["usage"] = execution.usage
completed_turns[-1]["checkpoint_after_turn"] = execution.checkpoint
completed_session_data["turns"] = completed_turns
completed_session_data["status"] = "idle"
try:
await db_client.update_workflow_run_text_session(
run_id,
session_data=completed_session_data,
checkpoint=execution.checkpoint,
expected_revision=text_session.revision,
)
except WorkflowRunTextSessionRevisionConflictError as e:
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
existing_usage_info = text_session.workflow_run.usage_info or {}
merged_usage_info = merge_text_chat_usage_info(
existing_usage_info,
execution.usage,
)
await db_client.update_workflow_run(
run_id,
initial_context=execution.initial_context,
usage_info=merged_usage_info,
gathered_context=execution.gathered_context,
state=execution.state,
is_completed=execution.is_completed,
)
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)
2026-05-18 13:26:29 +05:30
@router.post(
"/{workflow_id}/text-chat/sessions",
response_model=WorkflowRunTextSessionResponse,
)
async def create_text_chat_session(
workflow_id: int,
request: CreateTextChatSessionRequest,
user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
try:
workflow_run = await db_client.create_workflow_run(
name=session_name,
workflow_id=workflow_id,
mode=WorkflowRunMode.TEXTCHAT.value,
user_id=user.id,
initial_context=request.initial_context,
use_draft=True,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
annotations = {
"tester": {
"source": "workflow_editor",
"modality": "text",
}
}
if request.annotations:
annotations = {**annotations, **request.annotations}
workflow_run = await db_client.update_workflow_run(
workflow_run.id,
annotations=annotations,
)
text_session = await db_client.ensure_workflow_run_text_session(
workflow_run.id,
session_data=_default_session_data(),
checkpoint=_default_checkpoint(),
)
2026-05-19 08:24:39 +05:30
session_data = _normalize_session_data(text_session.session_data)
checkpoint = _normalize_checkpoint(text_session.checkpoint)
session_data["turns"] = [_build_pending_turn(user_text=None)]
session_data["status"] = "pending_assistant_turn"
checkpoint["anchor_turn_id"] = _latest_completed_turn_id(session_data["turns"])
try:
await db_client.update_workflow_run_text_session(
workflow_run.id,
session_data=session_data,
checkpoint=checkpoint,
expected_revision=text_session.revision,
)
except WorkflowRunTextSessionRevisionConflictError as e:
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
text_session = await _load_text_session_or_404(workflow_id, workflow_run.id, user)
return await _execute_pending_turn_and_build_response(
workflow_id=workflow_id,
run_id=workflow_run.id,
text_session=text_session,
user=user,
)
2026-05-18 13:26:29 +05:30
@router.get(
"/{workflow_id}/text-chat/sessions/{run_id}",
response_model=WorkflowRunTextSessionResponse,
)
async def get_text_chat_session(
workflow_id: int,
run_id: int,
user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)
@router.post(
"/{workflow_id}/text-chat/sessions/{run_id}/messages",
response_model=WorkflowRunTextSessionResponse,
)
async def append_text_chat_message(
workflow_id: int,
run_id: int,
request: AppendTextChatMessageRequest,
user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
session_data = _normalize_session_data(text_session.session_data)
checkpoint = _normalize_checkpoint(text_session.checkpoint)
active_turns, discarded_future = _truncate_future_turns(session_data)
2026-05-19 08:24:39 +05:30
active_turns.append(_build_pending_turn(user_text=request.text))
2026-05-18 13:26:29 +05:30
session_data["turns"] = active_turns
session_data["discarded_future"] = discarded_future
session_data["cursor_turn_id"] = None
session_data["status"] = "pending_assistant_turn"
checkpoint["anchor_turn_id"] = _latest_completed_turn_id(active_turns)
try:
2026-05-19 08:24:39 +05:30
await db_client.update_workflow_run_text_session(
2026-05-18 13:26:29 +05:30
run_id,
session_data=session_data,
checkpoint=checkpoint,
expected_revision=request.expected_revision,
)
except WorkflowRunTextSessionRevisionConflictError as e:
2026-05-19 08:24:39 +05:30
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
2026-05-18 13:26:29 +05:30
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
2026-05-19 08:24:39 +05:30
return await _execute_pending_turn_and_build_response(
workflow_id=workflow_id,
run_id=run_id,
text_session=text_session,
user=user,
)
2026-05-18 13:26:29 +05:30
@router.post(
"/{workflow_id}/text-chat/sessions/{run_id}/rewind",
response_model=WorkflowRunTextSessionResponse,
)
async def rewind_text_chat_session(
workflow_id: int,
run_id: int,
request: RewindTextChatSessionRequest,
user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
session_data = _normalize_session_data(text_session.session_data)
_validate_turn_cursor(session_data, request.cursor_turn_id)
session_data["cursor_turn_id"] = request.cursor_turn_id
session_data["status"] = "rewound" if request.cursor_turn_id else "idle"
try:
await db_client.update_workflow_run_text_session(
run_id,
session_data=session_data,
expected_revision=request.expected_revision,
)
except WorkflowRunTextSessionRevisionConflictError as e:
raise HTTPException(
status_code=409,
detail={
"message": "Text chat session revision conflict",
"expected_revision": e.expected_revision,
"actual_revision": e.actual_revision,
},
)
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)