From e313f2d2353af3604cefde8b621a2cefb354fec9 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Mon, 18 May 2026 13:26:29 +0530 Subject: [PATCH] feat: add backend foundations --- ...8891cbb6_add_workflow_run_text_sessions.py | 41 ++ api/db/db_client.py | 2 + api/db/models.py | 45 +++ api/db/workflow_run_text_session_client.py | 128 +++++++ api/enums.py | 1 + api/routes/main.py | 2 + api/routes/workflow_text_chat.py | 358 ++++++++++++++++++ 7 files changed, 577 insertions(+) create mode 100644 api/alembic/versions/2f638891cbb6_add_workflow_run_text_sessions.py create mode 100644 api/db/workflow_run_text_session_client.py create mode 100644 api/routes/workflow_text_chat.py diff --git a/api/alembic/versions/2f638891cbb6_add_workflow_run_text_sessions.py b/api/alembic/versions/2f638891cbb6_add_workflow_run_text_sessions.py new file mode 100644 index 0000000..bcf0f14 --- /dev/null +++ b/api/alembic/versions/2f638891cbb6_add_workflow_run_text_sessions.py @@ -0,0 +1,41 @@ +"""add workflow_run_text_sessions + +Revision ID: 2f638891cbb6 +Revises: 4c1f1e3e8ef2 +Create Date: 2026-05-18 12:58:58.573381 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '2f638891cbb6' +down_revision: Union[str, None] = '4c1f1e3e8ef2' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workflow_run_text_sessions', + sa.Column('workflow_run_id', sa.Integer(), nullable=False), + sa.Column('revision', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('session_data', sa.JSON(), server_default=sa.text("'{}'::json"), nullable=False), + sa.Column('checkpoint', sa.JSON(), server_default=sa.text("'{}'::json"), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['workflow_run_id'], ['workflow_runs.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('workflow_run_id') + ) + op.create_index('ix_workflow_run_text_sessions_updated_at', 'workflow_run_text_sessions', ['updated_at'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_workflow_run_text_sessions_updated_at', table_name='workflow_run_text_sessions') + op.drop_table('workflow_run_text_sessions') + # ### end Alembic commands ### diff --git a/api/db/db_client.py b/api/db/db_client.py index ef907e7..fa91d34 100644 --- a/api/db/db_client.py +++ b/api/db/db_client.py @@ -16,12 +16,14 @@ from api.db.webhook_credential_client import WebhookCredentialClient from api.db.workflow_client import WorkflowClient from api.db.workflow_recording_client import WorkflowRecordingClient from api.db.workflow_run_client import WorkflowRunClient +from api.db.workflow_run_text_session_client import WorkflowRunTextSessionClient from api.db.workflow_template_client import WorkflowTemplateClient class DBClient( WorkflowClient, WorkflowRunClient, + WorkflowRunTextSessionClient, UserClient, OrganizationClient, OrganizationConfigurationClient, diff --git a/api/db/models.py b/api/db/models.py index c62ca4f..d4843be 100644 --- a/api/db/models.py +++ b/api/db/models.py @@ -482,6 +482,12 @@ class WorkflowRunModel(Base): queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True) queued_run = relationship("QueuedRunModel", foreign_keys=[queued_run_id]) public_access_token = Column(String(36), nullable=True) + text_session = relationship( + "WorkflowRunTextSessionModel", + back_populates="workflow_run", + uselist=False, + cascade="all, delete-orphan", + ) # Indexes __table_args__ = ( @@ -501,6 +507,45 @@ class WorkflowRunModel(Base): ) +class WorkflowRunTextSessionModel(Base): + __tablename__ = "workflow_run_text_sessions" + + workflow_run_id = Column( + Integer, + ForeignKey("workflow_runs.id", ondelete="CASCADE"), + primary_key=True, + ) + workflow_run = relationship("WorkflowRunModel", back_populates="text_session") + revision = Column( + Integer, + nullable=False, + default=0, + server_default=text("0"), + ) + session_data = Column( + JSON, + nullable=False, + default=dict, + server_default=text("'{}'::json"), + ) + checkpoint = Column( + JSON, + nullable=False, + default=dict, + server_default=text("'{}'::json"), + ) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + ) + + __table_args__ = ( + Index("ix_workflow_run_text_sessions_updated_at", "updated_at"), + ) + + class OrganizationUsageCycleModel(Base): """ This model is used to track the usage of Dograh tokens for an organization for a given usage diff --git a/api/db/workflow_run_text_session_client.py b/api/db/workflow_run_text_session_client.py new file mode 100644 index 0000000..1412114 --- /dev/null +++ b/api/db/workflow_run_text_session_client.py @@ -0,0 +1,128 @@ +from sqlalchemy.future import select +from sqlalchemy.orm import joinedload + +from api.db.base_client import BaseDBClient +from api.db.models import ( + WorkflowModel, + WorkflowRunModel, + WorkflowRunTextSessionModel, +) + + +class WorkflowRunTextSessionRevisionConflictError(Exception): + def __init__(self, expected_revision: int, actual_revision: int): + self.expected_revision = expected_revision + self.actual_revision = actual_revision + super().__init__( + "Workflow run text session revision conflict: " + f"expected {expected_revision}, found {actual_revision}" + ) + + +class WorkflowRunTextSessionClient(BaseDBClient): + async def ensure_workflow_run_text_session( + self, + workflow_run_id: int, + session_data: dict | None = None, + checkpoint: dict | None = None, + ) -> WorkflowRunTextSessionModel: + async with self.async_session() as session: + result = await session.execute( + select(WorkflowRunTextSessionModel) + .where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id) + .with_for_update() + ) + text_session = result.scalars().first() + if text_session: + return text_session + + run_result = await session.execute( + select(WorkflowRunModel).where(WorkflowRunModel.id == workflow_run_id) + ) + workflow_run = run_result.scalars().first() + if not workflow_run: + raise ValueError(f"Workflow run with ID {workflow_run_id} not found") + + text_session = WorkflowRunTextSessionModel( + workflow_run_id=workflow_run_id, + session_data=session_data or {}, + checkpoint=checkpoint or {}, + ) + session.add(text_session) + try: + await session.commit() + except Exception as e: + await session.rollback() + raise e + await session.refresh(text_session) + return text_session + + async def get_workflow_run_text_session( + self, + workflow_run_id: int, + user_id: int | None = None, + organization_id: int | None = None, + ) -> WorkflowRunTextSessionModel | None: + async with self.async_session() as session: + query = ( + select(WorkflowRunTextSessionModel) + .options( + joinedload(WorkflowRunTextSessionModel.workflow_run).joinedload( + WorkflowRunModel.workflow + ) + ) + .join(WorkflowRunTextSessionModel.workflow_run) + .join(WorkflowRunModel.workflow) + .where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id) + ) + + if organization_id is not None: + query = query.where(WorkflowModel.organization_id == organization_id) + elif user_id is not None: + query = query.where(WorkflowModel.user_id == user_id) + + result = await session.execute(query) + return result.scalars().first() + + async def update_workflow_run_text_session( + self, + workflow_run_id: int, + *, + session_data: dict | None = None, + checkpoint: dict | None = None, + expected_revision: int | None = None, + ) -> WorkflowRunTextSessionModel: + async with self.async_session() as session: + result = await session.execute( + select(WorkflowRunTextSessionModel) + .where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id) + .with_for_update() + ) + text_session = result.scalars().first() + if not text_session: + raise ValueError( + f"Workflow run text session with run ID {workflow_run_id} not found" + ) + + if ( + expected_revision is not None + and text_session.revision != expected_revision + ): + raise WorkflowRunTextSessionRevisionConflictError( + expected_revision=expected_revision, + actual_revision=text_session.revision, + ) + + if session_data is not None: + text_session.session_data = session_data + if checkpoint is not None: + text_session.checkpoint = checkpoint + text_session.revision += 1 + + try: + await session.commit() + except Exception as e: + await session.rollback() + raise e + await session.refresh(text_session) + return text_session diff --git a/api/enums.py b/api/enums.py index b7655b1..8497ddc 100644 --- a/api/enums.py +++ b/api/enums.py @@ -27,6 +27,7 @@ class WorkflowRunMode(Enum): TELNYX = "telnyx" WEBRTC = "webrtc" SMALLWEBRTC = "smallwebrtc" + TEXTCHAT = "textchat" # Historical, not used anymore. Don't # use and don't remove diff --git a/api/routes/main.py b/api/routes/main.py index 6bcd3dc..514b65c 100644 --- a/api/routes/main.py +++ b/api/routes/main.py @@ -26,6 +26,7 @@ from api.routes.webrtc_signaling import router as webrtc_signaling_router from api.routes.workflow import router as workflow_router from api.routes.workflow_embed import router as workflow_embed_router from api.routes.workflow_recording import router as workflow_recording_router +from api.routes.workflow_text_chat import router as workflow_text_chat_router router = APIRouter( tags=["main"], @@ -35,6 +36,7 @@ router = APIRouter( router.include_router(telephony_router) router.include_router(superuser_router) router.include_router(workflow_router) +router.include_router(workflow_text_chat_router) router.include_router(user_router) router.include_router(campaign_router) router.include_router(credentials_router) diff --git a/api/routes/workflow_text_chat.py b/api/routes/workflow_text_chat.py new file mode 100644 index 0000000..3344570 --- /dev/null +++ b/api/routes/workflow_text_chat.py @@ -0,0 +1,358 @@ +from datetime import UTC, datetime +from typing import Any, Dict +from uuid import uuid4 + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from api.db import db_client +from api.db.models import UserModel, WorkflowRunTextSessionModel +from api.db.workflow_run_text_session_client import ( + WorkflowRunTextSessionRevisionConflictError, +) +from api.enums import WorkflowRunMode +from api.services.auth.depends import get_user + +router = APIRouter(prefix="/workflow", tags=["workflow-text-chat"]) + +TEXT_CHAT_SESSION_VERSION = 1 +TEXT_CHAT_CHECKPOINT_VERSION = 1 + + +class CreateTextChatSessionRequest(BaseModel): + name: str | None = None + initial_context: Dict[str, Any] | None = None + annotations: Dict[str, Any] | None = None + + +class AppendTextChatMessageRequest(BaseModel): + text: str = Field(min_length=1) + expected_revision: int | None = None + + +class RewindTextChatSessionRequest(BaseModel): + cursor_turn_id: str | None = None + expected_revision: int | None = None + + +class WorkflowRunTextSessionResponse(BaseModel): + workflow_run_id: int + workflow_id: int + name: str + mode: str + state: str + is_completed: bool + revision: int + initial_context: Dict[str, Any] | None = None + gathered_context: Dict[str, Any] | None = None + annotations: Dict[str, Any] | None = None + session_data: Dict[str, Any] + checkpoint: Dict[str, Any] + created_at: datetime + updated_at: datetime | None = None + + +def _default_session_data() -> Dict[str, Any]: + return { + "version": TEXT_CHAT_SESSION_VERSION, + "status": "idle", + "cursor_turn_id": None, + "turns": [], + "discarded_future": [], + "simulator": { + "enabled": False, + "config": {}, + }, + } + + +def _default_checkpoint() -> Dict[str, Any]: + return { + "version": TEXT_CHAT_CHECKPOINT_VERSION, + "anchor_turn_id": None, + "current_node_id": None, + "messages": [], + "gathered_context": {}, + "tool_state": {}, + } + + +def _normalize_session_data(session_data: Dict[str, Any] | None) -> Dict[str, Any]: + normalized = { + **_default_session_data(), + **(session_data or {}), + } + normalized["turns"] = list(normalized.get("turns") or []) + normalized["discarded_future"] = list(normalized.get("discarded_future") or []) + simulator = normalized.get("simulator") or {} + normalized["simulator"] = { + "enabled": bool(simulator.get("enabled", False)), + "config": dict(simulator.get("config") or {}), + } + return normalized + + +def _normalize_checkpoint(checkpoint: Dict[str, Any] | None) -> Dict[str, Any]: + normalized = { + **_default_checkpoint(), + **(checkpoint or {}), + } + normalized["messages"] = list(normalized.get("messages") or []) + normalized["gathered_context"] = dict(normalized.get("gathered_context") or {}) + normalized["tool_state"] = dict(normalized.get("tool_state") or {}) + return normalized + + +def _get_state_value(state: Any) -> str: + return state.value if hasattr(state, "value") else str(state) + + +def _build_response( + text_session: WorkflowRunTextSessionModel, +) -> WorkflowRunTextSessionResponse: + workflow_run = text_session.workflow_run + return WorkflowRunTextSessionResponse( + workflow_run_id=workflow_run.id, + workflow_id=workflow_run.workflow_id, + name=workflow_run.name, + mode=workflow_run.mode, + state=_get_state_value(workflow_run.state), + is_completed=workflow_run.is_completed, + revision=text_session.revision, + initial_context=workflow_run.initial_context, + gathered_context=workflow_run.gathered_context, + annotations=workflow_run.annotations, + session_data=_normalize_session_data(text_session.session_data), + checkpoint=_normalize_checkpoint(text_session.checkpoint), + created_at=text_session.created_at, + updated_at=text_session.updated_at, + ) + + +def _build_response_from_run_and_session(workflow_run, text_session): + return WorkflowRunTextSessionResponse( + workflow_run_id=workflow_run.id, + workflow_id=workflow_run.workflow_id, + name=workflow_run.name, + mode=workflow_run.mode, + state=_get_state_value(workflow_run.state), + is_completed=workflow_run.is_completed, + revision=text_session.revision, + initial_context=workflow_run.initial_context, + gathered_context=workflow_run.gathered_context, + annotations=workflow_run.annotations, + session_data=_normalize_session_data(text_session.session_data), + checkpoint=_normalize_checkpoint(text_session.checkpoint), + created_at=text_session.created_at, + updated_at=text_session.updated_at, + ) + + +def _validate_turn_cursor(session_data: Dict[str, Any], cursor_turn_id: str | None) -> None: + if cursor_turn_id is None: + return + if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]): + raise HTTPException(status_code=404, detail="Turn not found in text chat session") + + +def _truncate_future_turns( + session_data: Dict[str, Any], +) -> tuple[list[Dict[str, Any]], list[Dict[str, Any]]]: + cursor_turn_id = session_data.get("cursor_turn_id") + turns = list(session_data.get("turns") or []) + discarded_future = list(session_data.get("discarded_future") or []) + + if cursor_turn_id is None: + return turns, discarded_future + + for index, turn in enumerate(turns): + if turn.get("id") == cursor_turn_id: + active_turns = turns[: index + 1] + future_turns = turns[index + 1 :] + if future_turns: + discarded_future.append( + { + "rewound_from_turn_id": cursor_turn_id, + "discarded_at": datetime.now(UTC).isoformat(), + "turns": future_turns, + } + ) + return active_turns, discarded_future + + raise HTTPException(status_code=404, detail="Turn not found in text chat session") + + +def _latest_completed_turn_id(turns: list[Dict[str, Any]]) -> str | None: + for turn in reversed(turns): + if turn.get("status") == "completed" and turn.get("assistant_message"): + return turn.get("id") + return None + + +async def _load_text_session_or_404( + workflow_id: int, + run_id: int, + user: UserModel, +) -> WorkflowRunTextSessionModel: + text_session = await db_client.get_workflow_run_text_session( + run_id, organization_id=user.selected_organization_id + ) + if not text_session or not text_session.workflow_run: + raise HTTPException(status_code=404, detail="Text chat session not found") + if text_session.workflow_run.workflow_id != workflow_id: + raise HTTPException(status_code=404, detail="Text chat session not found") + if text_session.workflow_run.mode != WorkflowRunMode.TEXTCHAT.value: + raise HTTPException(status_code=400, detail="Workflow run is not a text chat session") + return text_session + + +@router.post( + "/{workflow_id}/text-chat/sessions", + response_model=WorkflowRunTextSessionResponse, +) +async def create_text_chat_session( + workflow_id: int, + request: CreateTextChatSessionRequest, + user: UserModel = Depends(get_user), +) -> WorkflowRunTextSessionResponse: + session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}" + try: + workflow_run = await db_client.create_workflow_run( + name=session_name, + workflow_id=workflow_id, + mode=WorkflowRunMode.TEXTCHAT.value, + user_id=user.id, + initial_context=request.initial_context, + use_draft=True, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + annotations = { + "tester": { + "source": "workflow_editor", + "modality": "text", + } + } + if request.annotations: + annotations = {**annotations, **request.annotations} + workflow_run = await db_client.update_workflow_run( + workflow_run.id, + annotations=annotations, + ) + + text_session = await db_client.ensure_workflow_run_text_session( + workflow_run.id, + session_data=_default_session_data(), + checkpoint=_default_checkpoint(), + ) + return _build_response_from_run_and_session(workflow_run, text_session) + + +@router.get( + "/{workflow_id}/text-chat/sessions/{run_id}", + response_model=WorkflowRunTextSessionResponse, +) +async def get_text_chat_session( + workflow_id: int, + run_id: int, + user: UserModel = Depends(get_user), +) -> WorkflowRunTextSessionResponse: + text_session = await _load_text_session_or_404(workflow_id, run_id, user) + return _build_response(text_session) + + +@router.post( + "/{workflow_id}/text-chat/sessions/{run_id}/messages", + response_model=WorkflowRunTextSessionResponse, +) +async def append_text_chat_message( + workflow_id: int, + run_id: int, + request: AppendTextChatMessageRequest, + user: UserModel = Depends(get_user), +) -> WorkflowRunTextSessionResponse: + text_session = await _load_text_session_or_404(workflow_id, run_id, user) + session_data = _normalize_session_data(text_session.session_data) + checkpoint = _normalize_checkpoint(text_session.checkpoint) + + active_turns, discarded_future = _truncate_future_turns(session_data) + now = datetime.now(UTC).isoformat() + turn_id = f"turn_{uuid4().hex[:12]}" + active_turns.append( + { + "id": turn_id, + "status": "pending", + "created_at": now, + "user_message": { + "text": request.text, + "created_at": now, + }, + "assistant_message": None, + "events": [], + "usage": {}, + } + ) + + session_data["turns"] = active_turns + session_data["discarded_future"] = discarded_future + session_data["cursor_turn_id"] = None + session_data["status"] = "pending_assistant_turn" + checkpoint["anchor_turn_id"] = _latest_completed_turn_id(active_turns) + + try: + text_session = await db_client.update_workflow_run_text_session( + run_id, + session_data=session_data, + checkpoint=checkpoint, + expected_revision=request.expected_revision, + ) + except WorkflowRunTextSessionRevisionConflictError as e: + raise HTTPException( + status_code=409, + detail={ + "message": "Text chat session revision conflict", + "expected_revision": e.expected_revision, + "actual_revision": e.actual_revision, + }, + ) + + text_session = await _load_text_session_or_404(workflow_id, run_id, user) + return _build_response(text_session) + + +@router.post( + "/{workflow_id}/text-chat/sessions/{run_id}/rewind", + response_model=WorkflowRunTextSessionResponse, +) +async def rewind_text_chat_session( + workflow_id: int, + run_id: int, + request: RewindTextChatSessionRequest, + user: UserModel = Depends(get_user), +) -> WorkflowRunTextSessionResponse: + text_session = await _load_text_session_or_404(workflow_id, run_id, user) + session_data = _normalize_session_data(text_session.session_data) + _validate_turn_cursor(session_data, request.cursor_turn_id) + + session_data["cursor_turn_id"] = request.cursor_turn_id + session_data["status"] = "rewound" if request.cursor_turn_id else "idle" + + try: + await db_client.update_workflow_run_text_session( + run_id, + session_data=session_data, + expected_revision=request.expected_revision, + ) + except WorkflowRunTextSessionRevisionConflictError as e: + raise HTTPException( + status_code=409, + detail={ + "message": "Text chat session revision conflict", + "expected_revision": e.expected_revision, + "actual_revision": e.actual_revision, + }, + ) + + text_session = await _load_text_session_or_404(workflow_id, run_id, user) + return _build_response(text_session)