dograh/api/routes/workflow_text_chat.py
2026-05-21 15:17:14 +05:30

282 lines
9.3 KiB
Python

from datetime import datetime
from typing import Any, Dict
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException
from pipecat.utils.run_context import set_current_run_id
from pydantic import BaseModel, Field
from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.quota_service import check_dograh_quota
from api.services.workflow.text_chat_session_service import (
TextChatPendingTurnLostError,
TextChatSessionExecutionError,
TextChatSessionRevisionConflictError,
TextChatTurnNotFoundError,
append_text_chat_user_message,
default_text_chat_checkpoint,
default_text_chat_session_data,
execute_pending_text_chat_turn,
initialize_text_chat_session,
normalize_text_chat_checkpoint,
normalize_text_chat_session_data,
rewind_text_chat_session_state,
)
# Router for the workflow text-chat endpoints: every route declared below is
# mounted under the "/workflow" prefix and tagged "workflow-text-chat".
router = APIRouter(tags=["workflow-text-chat"], prefix="/workflow")
# Request body for POST /{workflow_id}/text-chat/sessions.
# NOTE: documented with comments rather than a docstring on purpose -- Pydantic
# copies a model docstring into the generated JSON/OpenAPI schema description.
class CreateTextChatSessionRequest(BaseModel):
    # Optional run name; when missing or empty the endpoint generates a
    # "WR-TEXT-XXXXXX" name.
    name: str | None = None
    # Passed through unchanged as the new workflow run's initial_context.
    initial_context: Dict[str, Any] | None = None
    # Shallow-merged over the default {"tester": {...}} annotations, so a
    # top-level key supplied here replaces the default entry for that key.
    annotations: Dict[str, Any] | None = None
# Request body for POST /{workflow_id}/text-chat/sessions/{run_id}/messages.
# (Comments instead of a docstring: a docstring would become the schema description.)
class AppendTextChatMessageRequest(BaseModel):
    # The user's message; must be non-empty (min_length=1).
    text: str = Field(min_length=1)
    # Revision the client believes the session is at; a mismatch detected by
    # the service is returned as HTTP 409. NOTE(review): presumably None skips
    # the check -- confirm in append_text_chat_user_message.
    expected_revision: int | None = None
# Request body for POST /{workflow_id}/text-chat/sessions/{run_id}/rewind.
# (Comments instead of a docstring: a docstring would become the schema description.)
class RewindTextChatSessionRequest(BaseModel):
    # Turn to rewind the session to; an unknown id is returned as HTTP 404.
    # NOTE(review): the meaning of None (e.g. rewind to the start) is decided
    # by rewind_text_chat_session_state -- confirm there.
    cursor_turn_id: str | None = None
    # Revision the client believes the session is at; a mismatch is HTTP 409.
    expected_revision: int | None = None
# Response body returned by every text-chat endpoint in this module; built by
# _build_response from a text session and its associated workflow run.
# (Comments instead of a docstring: a docstring would become the schema description.)
class WorkflowRunTextSessionResponse(BaseModel):
    workflow_run_id: int
    workflow_id: int
    name: str
    mode: str
    # String form of the run's state (enum-like values are unwrapped to .value).
    state: str
    is_completed: bool
    # Current session revision; presumably what clients send back as
    # expected_revision on mutating requests -- confirm with the frontend.
    revision: int
    initial_context: Dict[str, Any] | None = None
    gathered_context: Dict[str, Any] | None = None
    annotations: Dict[str, Any] | None = None
    # Both normalized via the text-chat session service before serialization.
    session_data: Dict[str, Any]
    checkpoint: Dict[str, Any]
    created_at: datetime
    updated_at: datetime | None = None
def _get_state_value(state: Any) -> str:
return state.value if hasattr(state, "value") else str(state)
def _build_response(
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionResponse:
    """Serialize a text chat session and its workflow run into the API response.

    Run-level fields come from ``text_session.workflow_run``. ``revision`` and
    the timestamps come from the text session itself. ``session_data`` and
    ``checkpoint`` are passed through the service's normalizers first.
    """
    run = text_session.workflow_run
    # Built as a dict literal so the fields are read in the same order as the
    # response model declares them.
    payload = {
        "workflow_run_id": run.id,
        "workflow_id": run.workflow_id,
        "name": run.name,
        "mode": run.mode,
        "state": _get_state_value(run.state),
        "is_completed": run.is_completed,
        "revision": text_session.revision,
        "initial_context": run.initial_context,
        "gathered_context": run.gathered_context,
        "annotations": run.annotations,
        "session_data": normalize_text_chat_session_data(text_session.session_data),
        "checkpoint": normalize_text_chat_checkpoint(text_session.checkpoint),
        "created_at": text_session.created_at,
        "updated_at": text_session.updated_at,
    }
    return WorkflowRunTextSessionResponse(**payload)
def _revision_conflict_detail(e: Any) -> dict[str, Any]:
return {
"message": "Text chat session revision conflict",
"expected_revision": e.expected_revision,
"actual_revision": e.actual_revision,
}
def _require_selected_organization_id(user: UserModel) -> int:
    """Return the organization the user currently has selected.

    Raises:
        HTTPException: 403 when the user has no organization selected.
    """
    org_id = user.selected_organization_id
    if org_id is not None:
        return org_id
    raise HTTPException(status_code=403, detail="Organization context is required")
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
    """Reject the request when ``check_dograh_quota`` reports no remaining quota.

    Raises:
        HTTPException: 402, with the quota service's error message as detail.
    """
    result = await check_dograh_quota(user, workflow_id=workflow_id)
    if result.has_quota:
        return
    raise HTTPException(status_code=402, detail=result.error_message)
async def _load_text_session_or_404(
    workflow_id: int,
    run_id: int,
    user: UserModel,
) -> WorkflowRunTextSessionModel:
    """Fetch a text chat session scoped to the user's selected organization.

    The run id is registered via ``set_current_run_id`` before any checks.
    NOTE(review): presumably this is for log/trace correlation; verify against
    pipecat's run_context.

    Raises:
        HTTPException: 403 when no organization is selected. 404 when the
            session or its workflow run is missing, or when the run belongs to
            a different workflow. 400 when the run is not a TEXTCHAT run.
    """
    set_current_run_id(run_id)
    org_id = _require_selected_organization_id(user)
    session = await db_client.get_workflow_run_text_session(
        run_id, organization_id=org_id
    )

    run = session.workflow_run if session else None
    # A run that belongs to another workflow is reported exactly like a
    # missing session.
    if not run or run.workflow_id != workflow_id:
        raise HTTPException(status_code=404, detail="Text chat session not found")

    if run.mode != WorkflowRunMode.TEXTCHAT.value:
        raise HTTPException(
            status_code=400, detail="Workflow run is not a text chat session"
        )

    return session
async def _execute_pending_turn_response(
    *,
    workflow_id: int,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionResponse:
    """Run the session's pending turn and return the serialized, updated session.

    Delegates to ``execute_pending_text_chat_turn`` and maps its domain errors
    to HTTP errors. The original exception is chained as ``__cause__``.

    Raises:
        HTTPException: 409 on a revision conflict (detail lists the expected
            and actual revisions). 500 when the pending turn was lost or its
            execution failed (detail is the error message).
    """
    try:
        updated_text_session = await execute_pending_text_chat_turn(
            workflow_id=workflow_id,
            run_id=run_id,
            text_session=text_session,
        )
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    except (TextChatPendingTurnLostError, TextChatSessionExecutionError) as e:
        # Both failure modes surface identically to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return _build_response(updated_text_session)
def _merge_text_chat_annotations(
    extra: Dict[str, Any] | None,
) -> Dict[str, Any]:
    """Return the annotations to store on a newly created text chat run.

    Every run gets a default ``"tester"`` entry recording the workflow editor
    as its source and text as its modality. A non-empty ``extra`` mapping is
    shallow-merged on top, so a caller-supplied ``"tester"`` key replaces the
    default entry entirely.
    """
    annotations: Dict[str, Any] = {
        "tester": {
            "source": "workflow_editor",
            "modality": "text",
        }
    }
    if extra:
        annotations = {**annotations, **extra}
    return annotations


@router.post(
    "/{workflow_id}/text-chat/sessions",
    response_model=WorkflowRunTextSessionResponse,
)
async def create_text_chat_session(
    workflow_id: int,
    request: CreateTextChatSessionRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Create a TEXTCHAT workflow run (from the draft workflow) and run its first turn.

    Steps:
        1. Check the organization context and quota.
        2. Create the run and store its annotations.
        3. Create and initialize the text session.
        4. Execute the pending turn and return the resulting session.

    Raises:
        HTTPException: 403 without a selected organization. 402 when out of
            quota. 404 when ``create_workflow_run`` raises ``ValueError``.
            409 on a revision conflict. 500 when turn execution fails.
    """
    organization_id = _require_selected_organization_id(user)
    await _ensure_text_chat_quota(user, workflow_id)

    # An omitted or empty name falls back to a short random identifier.
    session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
    try:
        workflow_run = await db_client.create_workflow_run(
            name=session_name,
            workflow_id=workflow_id,
            mode=WorkflowRunMode.TEXTCHAT.value,
            user_id=user.id,
            initial_context=request.initial_context,
            use_draft=True,
            organization_id=organization_id,
        )
    except ValueError as e:
        # NOTE(review): presumably raised when the workflow/draft is not found
        # for this organization -- confirm in db_client.create_workflow_run.
        raise HTTPException(status_code=404, detail=str(e)) from e
    set_current_run_id(workflow_run.id)

    workflow_run = await db_client.update_workflow_run(
        workflow_run.id,
        annotations=_merge_text_chat_annotations(request.annotations),
    )
    text_session = await db_client.ensure_workflow_run_text_session(
        workflow_run.id,
        session_data=default_text_chat_session_data(),
        checkpoint=default_text_chat_checkpoint(),
    )
    try:
        text_session = await initialize_text_chat_session(
            run_id=workflow_run.id,
            text_session=text_session,
        )
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    return await _execute_pending_turn_response(
        workflow_id=workflow_id,
        run_id=workflow_run.id,
        text_session=text_session,
    )
@router.get(
    "/{workflow_id}/text-chat/sessions/{run_id}",
    response_model=WorkflowRunTextSessionResponse,
)
async def get_text_chat_session(
    workflow_id: int,
    run_id: int,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Return the current state of a text chat session without running a turn.

    Raises:
        HTTPException: 403/404/400 as raised by ``_load_text_session_or_404``.
    """
    return _build_response(
        await _load_text_session_or_404(workflow_id, run_id, user)
    )
@router.post(
    "/{workflow_id}/text-chat/sessions/{run_id}/messages",
    response_model=WorkflowRunTextSessionResponse,
)
async def append_text_chat_message(
    workflow_id: int,
    run_id: int,
    request: AppendTextChatMessageRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Append a user message to a text chat session and run the resulting turn.

    Raises:
        HTTPException: 403/404/400 from session loading. 402 when out of quota.
            409 on a revision conflict (while appending or executing). 500 when
            turn execution fails.
    """
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    # Quota is checked only after the session is confirmed to exist, so an
    # unknown session reports 404 rather than 402.
    await _ensure_text_chat_quota(user, workflow_id)
    try:
        text_session = await append_text_chat_user_message(
            run_id=run_id,
            text_session=text_session,
            user_text=request.text,
            expected_revision=request.expected_revision,
        )
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    return await _execute_pending_turn_response(
        workflow_id=workflow_id,
        run_id=run_id,
        text_session=text_session,
    )
@router.post(
    "/{workflow_id}/text-chat/sessions/{run_id}/rewind",
    response_model=WorkflowRunTextSessionResponse,
)
async def rewind_text_chat_session(
    workflow_id: int,
    run_id: int,
    request: RewindTextChatSessionRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Rewind a text chat session to ``request.cursor_turn_id``.

    No new turn is executed and no quota is consumed. The rewound session is
    returned as-is.

    Raises:
        HTTPException: 403/404/400 from session loading. 404 when the cursor
            turn is unknown. 409 on a revision conflict.
    """
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    try:
        text_session = await rewind_text_chat_session_state(
            run_id=run_id,
            text_session=text_session,
            cursor_turn_id=request.cursor_turn_id,
            expected_revision=request.expected_revision,
        )
    except TextChatTurnNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    return _build_response(text_session)