mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add backend foundations
This commit is contained in:
parent
0097974444
commit
e313f2d235
7 changed files with 577 additions and 0 deletions
|
|
@ -0,0 +1,41 @@
|
|||
"""add workflow_run_text_sessions
|
||||
|
||||
Revision ID: 2f638891cbb6
|
||||
Revises: 4c1f1e3e8ef2
|
||||
Create Date: 2026-05-18 12:58:58.573381
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2f638891cbb6'
|
||||
down_revision: Union[str, None] = '4c1f1e3e8ef2'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('workflow_run_text_sessions',
|
||||
sa.Column('workflow_run_id', sa.Integer(), nullable=False),
|
||||
sa.Column('revision', sa.Integer(), server_default=sa.text('0'), nullable=False),
|
||||
sa.Column('session_data', sa.JSON(), server_default=sa.text("'{}'::json"), nullable=False),
|
||||
sa.Column('checkpoint', sa.JSON(), server_default=sa.text("'{}'::json"), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['workflow_run_id'], ['workflow_runs.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('workflow_run_id')
|
||||
)
|
||||
op.create_index('ix_workflow_run_text_sessions_updated_at', 'workflow_run_text_sessions', ['updated_at'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('ix_workflow_run_text_sessions_updated_at', table_name='workflow_run_text_sessions')
|
||||
op.drop_table('workflow_run_text_sessions')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -16,12 +16,14 @@ from api.db.webhook_credential_client import WebhookCredentialClient
|
|||
from api.db.workflow_client import WorkflowClient
|
||||
from api.db.workflow_recording_client import WorkflowRecordingClient
|
||||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.db.workflow_run_text_session_client import WorkflowRunTextSessionClient
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
|
||||
|
||||
class DBClient(
|
||||
WorkflowClient,
|
||||
WorkflowRunClient,
|
||||
WorkflowRunTextSessionClient,
|
||||
UserClient,
|
||||
OrganizationClient,
|
||||
OrganizationConfigurationClient,
|
||||
|
|
|
|||
|
|
@ -482,6 +482,12 @@ class WorkflowRunModel(Base):
|
|||
queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True)
|
||||
queued_run = relationship("QueuedRunModel", foreign_keys=[queued_run_id])
|
||||
public_access_token = Column(String(36), nullable=True)
|
||||
text_session = relationship(
|
||||
"WorkflowRunTextSessionModel",
|
||||
back_populates="workflow_run",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
|
|
@ -501,6 +507,45 @@ class WorkflowRunModel(Base):
|
|||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionModel(Base):
|
||||
__tablename__ = "workflow_run_text_sessions"
|
||||
|
||||
workflow_run_id = Column(
|
||||
Integer,
|
||||
ForeignKey("workflow_runs.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
workflow_run = relationship("WorkflowRunModel", back_populates="text_session")
|
||||
revision = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
server_default=text("0"),
|
||||
)
|
||||
session_data = Column(
|
||||
JSON,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default=text("'{}'::json"),
|
||||
)
|
||||
checkpoint = Column(
|
||||
JSON,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default=text("'{}'::json"),
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_workflow_run_text_sessions_updated_at", "updated_at"),
|
||||
)
|
||||
|
||||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
|
|
|
|||
128
api/db/workflow_run_text_session_client.py
Normal file
128
api/db/workflow_run_text_session_client.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import (
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunTextSessionModel,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionRevisionConflictError(Exception):
|
||||
def __init__(self, expected_revision: int, actual_revision: int):
|
||||
self.expected_revision = expected_revision
|
||||
self.actual_revision = actual_revision
|
||||
super().__init__(
|
||||
"Workflow run text session revision conflict: "
|
||||
f"expected {expected_revision}, found {actual_revision}"
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionClient(BaseDBClient):
|
||||
async def ensure_workflow_run_text_session(
|
||||
self,
|
||||
workflow_run_id: int,
|
||||
session_data: dict | None = None,
|
||||
checkpoint: dict | None = None,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunTextSessionModel)
|
||||
.where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
|
||||
.with_for_update()
|
||||
)
|
||||
text_session = result.scalars().first()
|
||||
if text_session:
|
||||
return text_session
|
||||
|
||||
run_result = await session.execute(
|
||||
select(WorkflowRunModel).where(WorkflowRunModel.id == workflow_run_id)
|
||||
)
|
||||
workflow_run = run_result.scalars().first()
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run with ID {workflow_run_id} not found")
|
||||
|
||||
text_session = WorkflowRunTextSessionModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
session_data=session_data or {},
|
||||
checkpoint=checkpoint or {},
|
||||
)
|
||||
session.add(text_session)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(text_session)
|
||||
return text_session
|
||||
|
||||
async def get_workflow_run_text_session(
|
||||
self,
|
||||
workflow_run_id: int,
|
||||
user_id: int | None = None,
|
||||
organization_id: int | None = None,
|
||||
) -> WorkflowRunTextSessionModel | None:
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowRunTextSessionModel)
|
||||
.options(
|
||||
joinedload(WorkflowRunTextSessionModel.workflow_run).joinedload(
|
||||
WorkflowRunModel.workflow
|
||||
)
|
||||
)
|
||||
.join(WorkflowRunTextSessionModel.workflow_run)
|
||||
.join(WorkflowRunModel.workflow)
|
||||
.where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
|
||||
)
|
||||
|
||||
if organization_id is not None:
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id is not None:
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_workflow_run_text_session(
|
||||
self,
|
||||
workflow_run_id: int,
|
||||
*,
|
||||
session_data: dict | None = None,
|
||||
checkpoint: dict | None = None,
|
||||
expected_revision: int | None = None,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunTextSessionModel)
|
||||
.where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
|
||||
.with_for_update()
|
||||
)
|
||||
text_session = result.scalars().first()
|
||||
if not text_session:
|
||||
raise ValueError(
|
||||
f"Workflow run text session with run ID {workflow_run_id} not found"
|
||||
)
|
||||
|
||||
if (
|
||||
expected_revision is not None
|
||||
and text_session.revision != expected_revision
|
||||
):
|
||||
raise WorkflowRunTextSessionRevisionConflictError(
|
||||
expected_revision=expected_revision,
|
||||
actual_revision=text_session.revision,
|
||||
)
|
||||
|
||||
if session_data is not None:
|
||||
text_session.session_data = session_data
|
||||
if checkpoint is not None:
|
||||
text_session.checkpoint = checkpoint
|
||||
text_session.revision += 1
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(text_session)
|
||||
return text_session
|
||||
|
|
@ -27,6 +27,7 @@ class WorkflowRunMode(Enum):
|
|||
TELNYX = "telnyx"
|
||||
WEBRTC = "webrtc"
|
||||
SMALLWEBRTC = "smallwebrtc"
|
||||
TEXTCHAT = "textchat"
|
||||
|
||||
# Historical, not used anymore. Don't
|
||||
# use and don't remove
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from api.routes.webrtc_signaling import router as webrtc_signaling_router
|
|||
from api.routes.workflow import router as workflow_router
|
||||
from api.routes.workflow_embed import router as workflow_embed_router
|
||||
from api.routes.workflow_recording import router as workflow_recording_router
|
||||
from api.routes.workflow_text_chat import router as workflow_text_chat_router
|
||||
|
||||
router = APIRouter(
|
||||
tags=["main"],
|
||||
|
|
@ -35,6 +36,7 @@ router = APIRouter(
|
|||
router.include_router(telephony_router)
|
||||
router.include_router(superuser_router)
|
||||
router.include_router(workflow_router)
|
||||
router.include_router(workflow_text_chat_router)
|
||||
router.include_router(user_router)
|
||||
router.include_router(campaign_router)
|
||||
router.include_router(credentials_router)
|
||||
|
|
|
|||
358
api/routes/workflow_text_chat.py
Normal file
358
api/routes/workflow_text_chat.py
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
router = APIRouter(prefix="/workflow", tags=["workflow-text-chat"])
|
||||
|
||||
TEXT_CHAT_SESSION_VERSION = 1
|
||||
TEXT_CHAT_CHECKPOINT_VERSION = 1
|
||||
|
||||
|
||||
class CreateTextChatSessionRequest(BaseModel):
|
||||
name: str | None = None
|
||||
initial_context: Dict[str, Any] | None = None
|
||||
annotations: Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AppendTextChatMessageRequest(BaseModel):
|
||||
text: str = Field(min_length=1)
|
||||
expected_revision: int | None = None
|
||||
|
||||
|
||||
class RewindTextChatSessionRequest(BaseModel):
|
||||
cursor_turn_id: str | None = None
|
||||
expected_revision: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunTextSessionResponse(BaseModel):
|
||||
workflow_run_id: int
|
||||
workflow_id: int
|
||||
name: str
|
||||
mode: str
|
||||
state: str
|
||||
is_completed: bool
|
||||
revision: int
|
||||
initial_context: Dict[str, Any] | None = None
|
||||
gathered_context: Dict[str, Any] | None = None
|
||||
annotations: Dict[str, Any] | None = None
|
||||
session_data: Dict[str, Any]
|
||||
checkpoint: Dict[str, Any]
|
||||
created_at: datetime
|
||||
updated_at: datetime | None = None
|
||||
|
||||
|
||||
def _default_session_data() -> Dict[str, Any]:
|
||||
return {
|
||||
"version": TEXT_CHAT_SESSION_VERSION,
|
||||
"status": "idle",
|
||||
"cursor_turn_id": None,
|
||||
"turns": [],
|
||||
"discarded_future": [],
|
||||
"simulator": {
|
||||
"enabled": False,
|
||||
"config": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _default_checkpoint() -> Dict[str, Any]:
|
||||
return {
|
||||
"version": TEXT_CHAT_CHECKPOINT_VERSION,
|
||||
"anchor_turn_id": None,
|
||||
"current_node_id": None,
|
||||
"messages": [],
|
||||
"gathered_context": {},
|
||||
"tool_state": {},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_session_data(session_data: Dict[str, Any] | None) -> Dict[str, Any]:
|
||||
normalized = {
|
||||
**_default_session_data(),
|
||||
**(session_data or {}),
|
||||
}
|
||||
normalized["turns"] = list(normalized.get("turns") or [])
|
||||
normalized["discarded_future"] = list(normalized.get("discarded_future") or [])
|
||||
simulator = normalized.get("simulator") or {}
|
||||
normalized["simulator"] = {
|
||||
"enabled": bool(simulator.get("enabled", False)),
|
||||
"config": dict(simulator.get("config") or {}),
|
||||
}
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_checkpoint(checkpoint: Dict[str, Any] | None) -> Dict[str, Any]:
|
||||
normalized = {
|
||||
**_default_checkpoint(),
|
||||
**(checkpoint or {}),
|
||||
}
|
||||
normalized["messages"] = list(normalized.get("messages") or [])
|
||||
normalized["gathered_context"] = dict(normalized.get("gathered_context") or {})
|
||||
normalized["tool_state"] = dict(normalized.get("tool_state") or {})
|
||||
return normalized
|
||||
|
||||
|
||||
def _get_state_value(state: Any) -> str:
|
||||
return state.value if hasattr(state, "value") else str(state)
|
||||
|
||||
|
||||
def _build_response(
|
||||
text_session: WorkflowRunTextSessionModel,
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
workflow_run = text_session.workflow_run
|
||||
return WorkflowRunTextSessionResponse(
|
||||
workflow_run_id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
name=workflow_run.name,
|
||||
mode=workflow_run.mode,
|
||||
state=_get_state_value(workflow_run.state),
|
||||
is_completed=workflow_run.is_completed,
|
||||
revision=text_session.revision,
|
||||
initial_context=workflow_run.initial_context,
|
||||
gathered_context=workflow_run.gathered_context,
|
||||
annotations=workflow_run.annotations,
|
||||
session_data=_normalize_session_data(text_session.session_data),
|
||||
checkpoint=_normalize_checkpoint(text_session.checkpoint),
|
||||
created_at=text_session.created_at,
|
||||
updated_at=text_session.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _build_response_from_run_and_session(workflow_run, text_session):
|
||||
return WorkflowRunTextSessionResponse(
|
||||
workflow_run_id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
name=workflow_run.name,
|
||||
mode=workflow_run.mode,
|
||||
state=_get_state_value(workflow_run.state),
|
||||
is_completed=workflow_run.is_completed,
|
||||
revision=text_session.revision,
|
||||
initial_context=workflow_run.initial_context,
|
||||
gathered_context=workflow_run.gathered_context,
|
||||
annotations=workflow_run.annotations,
|
||||
session_data=_normalize_session_data(text_session.session_data),
|
||||
checkpoint=_normalize_checkpoint(text_session.checkpoint),
|
||||
created_at=text_session.created_at,
|
||||
updated_at=text_session.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _validate_turn_cursor(session_data: Dict[str, Any], cursor_turn_id: str | None) -> None:
|
||||
if cursor_turn_id is None:
|
||||
return
|
||||
if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]):
|
||||
raise HTTPException(status_code=404, detail="Turn not found in text chat session")
|
||||
|
||||
|
||||
def _truncate_future_turns(
|
||||
session_data: Dict[str, Any],
|
||||
) -> tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
|
||||
cursor_turn_id = session_data.get("cursor_turn_id")
|
||||
turns = list(session_data.get("turns") or [])
|
||||
discarded_future = list(session_data.get("discarded_future") or [])
|
||||
|
||||
if cursor_turn_id is None:
|
||||
return turns, discarded_future
|
||||
|
||||
for index, turn in enumerate(turns):
|
||||
if turn.get("id") == cursor_turn_id:
|
||||
active_turns = turns[: index + 1]
|
||||
future_turns = turns[index + 1 :]
|
||||
if future_turns:
|
||||
discarded_future.append(
|
||||
{
|
||||
"rewound_from_turn_id": cursor_turn_id,
|
||||
"discarded_at": datetime.now(UTC).isoformat(),
|
||||
"turns": future_turns,
|
||||
}
|
||||
)
|
||||
return active_turns, discarded_future
|
||||
|
||||
raise HTTPException(status_code=404, detail="Turn not found in text chat session")
|
||||
|
||||
|
||||
def _latest_completed_turn_id(turns: list[Dict[str, Any]]) -> str | None:
|
||||
for turn in reversed(turns):
|
||||
if turn.get("status") == "completed" and turn.get("assistant_message"):
|
||||
return turn.get("id")
|
||||
return None
|
||||
|
||||
|
||||
async def _load_text_session_or_404(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not text_session or not text_session.workflow_run:
|
||||
raise HTTPException(status_code=404, detail="Text chat session not found")
|
||||
if text_session.workflow_run.workflow_id != workflow_id:
|
||||
raise HTTPException(status_code=404, detail="Text chat session not found")
|
||||
if text_session.workflow_run.mode != WorkflowRunMode.TEXTCHAT.value:
|
||||
raise HTTPException(status_code=400, detail="Workflow run is not a text chat session")
|
||||
return text_session
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{workflow_id}/text-chat/sessions",
|
||||
response_model=WorkflowRunTextSessionResponse,
|
||||
)
|
||||
async def create_text_chat_session(
|
||||
workflow_id: int,
|
||||
request: CreateTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
|
||||
try:
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
name=session_name,
|
||||
workflow_id=workflow_id,
|
||||
mode=WorkflowRunMode.TEXTCHAT.value,
|
||||
user_id=user.id,
|
||||
initial_context=request.initial_context,
|
||||
use_draft=True,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
annotations = {
|
||||
"tester": {
|
||||
"source": "workflow_editor",
|
||||
"modality": "text",
|
||||
}
|
||||
}
|
||||
if request.annotations:
|
||||
annotations = {**annotations, **request.annotations}
|
||||
workflow_run = await db_client.update_workflow_run(
|
||||
workflow_run.id,
|
||||
annotations=annotations,
|
||||
)
|
||||
|
||||
text_session = await db_client.ensure_workflow_run_text_session(
|
||||
workflow_run.id,
|
||||
session_data=_default_session_data(),
|
||||
checkpoint=_default_checkpoint(),
|
||||
)
|
||||
return _build_response_from_run_and_session(workflow_run, text_session)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{workflow_id}/text-chat/sessions/{run_id}",
|
||||
response_model=WorkflowRunTextSessionResponse,
|
||||
)
|
||||
async def get_text_chat_session(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{workflow_id}/text-chat/sessions/{run_id}/messages",
|
||||
response_model=WorkflowRunTextSessionResponse,
|
||||
)
|
||||
async def append_text_chat_message(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: AppendTextChatMessageRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
session_data = _normalize_session_data(text_session.session_data)
|
||||
checkpoint = _normalize_checkpoint(text_session.checkpoint)
|
||||
|
||||
active_turns, discarded_future = _truncate_future_turns(session_data)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
turn_id = f"turn_{uuid4().hex[:12]}"
|
||||
active_turns.append(
|
||||
{
|
||||
"id": turn_id,
|
||||
"status": "pending",
|
||||
"created_at": now,
|
||||
"user_message": {
|
||||
"text": request.text,
|
||||
"created_at": now,
|
||||
},
|
||||
"assistant_message": None,
|
||||
"events": [],
|
||||
"usage": {},
|
||||
}
|
||||
)
|
||||
|
||||
session_data["turns"] = active_turns
|
||||
session_data["discarded_future"] = discarded_future
|
||||
session_data["cursor_turn_id"] = None
|
||||
session_data["status"] = "pending_assistant_turn"
|
||||
checkpoint["anchor_turn_id"] = _latest_completed_turn_id(active_turns)
|
||||
|
||||
try:
|
||||
text_session = await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=session_data,
|
||||
checkpoint=checkpoint,
|
||||
expected_revision=request.expected_revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"message": "Text chat session revision conflict",
|
||||
"expected_revision": e.expected_revision,
|
||||
"actual_revision": e.actual_revision,
|
||||
},
|
||||
)
|
||||
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{workflow_id}/text-chat/sessions/{run_id}/rewind",
|
||||
response_model=WorkflowRunTextSessionResponse,
|
||||
)
|
||||
async def rewind_text_chat_session(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: RewindTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
session_data = _normalize_session_data(text_session.session_data)
|
||||
_validate_turn_cursor(session_data, request.cursor_turn_id)
|
||||
|
||||
session_data["cursor_turn_id"] = request.cursor_turn_id
|
||||
session_data["status"] = "rewound" if request.cursor_turn_id else "idle"
|
||||
|
||||
try:
|
||||
await db_client.update_workflow_run_text_session(
|
||||
run_id,
|
||||
session_data=session_data,
|
||||
expected_revision=request.expected_revision,
|
||||
)
|
||||
except WorkflowRunTextSessionRevisionConflictError as e:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"message": "Text chat session revision conflict",
|
||||
"expected_revision": e.expected_revision,
|
||||
"actual_revision": e.actual_revision,
|
||||
},
|
||||
)
|
||||
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
Loading…
Add table
Add a link
Reference in a new issue