mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: add backend foundations
This commit is contained in:
parent
0097974444
commit
e313f2d235
7 changed files with 577 additions and 0 deletions
|
|
@ -16,12 +16,14 @@ from api.db.webhook_credential_client import WebhookCredentialClient
|
|||
from api.db.workflow_client import WorkflowClient
|
||||
from api.db.workflow_recording_client import WorkflowRecordingClient
|
||||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.db.workflow_run_text_session_client import WorkflowRunTextSessionClient
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
|
||||
|
||||
class DBClient(
|
||||
WorkflowClient,
|
||||
WorkflowRunClient,
|
||||
WorkflowRunTextSessionClient,
|
||||
UserClient,
|
||||
OrganizationClient,
|
||||
OrganizationConfigurationClient,
|
||||
|
|
|
|||
|
|
@ -482,6 +482,12 @@ class WorkflowRunModel(Base):
|
|||
queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True)
|
||||
queued_run = relationship("QueuedRunModel", foreign_keys=[queued_run_id])
|
||||
public_access_token = Column(String(36), nullable=True)
|
||||
text_session = relationship(
|
||||
"WorkflowRunTextSessionModel",
|
||||
back_populates="workflow_run",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
|
|
@ -501,6 +507,45 @@ class WorkflowRunModel(Base):
|
|||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionModel(Base):
|
||||
__tablename__ = "workflow_run_text_sessions"
|
||||
|
||||
workflow_run_id = Column(
|
||||
Integer,
|
||||
ForeignKey("workflow_runs.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
workflow_run = relationship("WorkflowRunModel", back_populates="text_session")
|
||||
revision = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
server_default=text("0"),
|
||||
)
|
||||
session_data = Column(
|
||||
JSON,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default=text("'{}'::json"),
|
||||
)
|
||||
checkpoint = Column(
|
||||
JSON,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default=text("'{}'::json"),
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_workflow_run_text_sessions_updated_at", "updated_at"),
|
||||
)
|
||||
|
||||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
|
|
|
|||
128
api/db/workflow_run_text_session_client.py
Normal file
128
api/db/workflow_run_text_session_client.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import (
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunTextSessionModel,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionRevisionConflictError(Exception):
|
||||
def __init__(self, expected_revision: int, actual_revision: int):
|
||||
self.expected_revision = expected_revision
|
||||
self.actual_revision = actual_revision
|
||||
super().__init__(
|
||||
"Workflow run text session revision conflict: "
|
||||
f"expected {expected_revision}, found {actual_revision}"
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionClient(BaseDBClient):
|
||||
async def ensure_workflow_run_text_session(
|
||||
self,
|
||||
workflow_run_id: int,
|
||||
session_data: dict | None = None,
|
||||
checkpoint: dict | None = None,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunTextSessionModel)
|
||||
.where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
|
||||
.with_for_update()
|
||||
)
|
||||
text_session = result.scalars().first()
|
||||
if text_session:
|
||||
return text_session
|
||||
|
||||
run_result = await session.execute(
|
||||
select(WorkflowRunModel).where(WorkflowRunModel.id == workflow_run_id)
|
||||
)
|
||||
workflow_run = run_result.scalars().first()
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run with ID {workflow_run_id} not found")
|
||||
|
||||
text_session = WorkflowRunTextSessionModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
session_data=session_data or {},
|
||||
checkpoint=checkpoint or {},
|
||||
)
|
||||
session.add(text_session)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(text_session)
|
||||
return text_session
|
||||
|
||||
async def get_workflow_run_text_session(
|
||||
self,
|
||||
workflow_run_id: int,
|
||||
user_id: int | None = None,
|
||||
organization_id: int | None = None,
|
||||
) -> WorkflowRunTextSessionModel | None:
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowRunTextSessionModel)
|
||||
.options(
|
||||
joinedload(WorkflowRunTextSessionModel.workflow_run).joinedload(
|
||||
WorkflowRunModel.workflow
|
||||
)
|
||||
)
|
||||
.join(WorkflowRunTextSessionModel.workflow_run)
|
||||
.join(WorkflowRunModel.workflow)
|
||||
.where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
|
||||
)
|
||||
|
||||
if organization_id is not None:
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id is not None:
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_workflow_run_text_session(
|
||||
self,
|
||||
workflow_run_id: int,
|
||||
*,
|
||||
session_data: dict | None = None,
|
||||
checkpoint: dict | None = None,
|
||||
expected_revision: int | None = None,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunTextSessionModel)
|
||||
.where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
|
||||
.with_for_update()
|
||||
)
|
||||
text_session = result.scalars().first()
|
||||
if not text_session:
|
||||
raise ValueError(
|
||||
f"Workflow run text session with run ID {workflow_run_id} not found"
|
||||
)
|
||||
|
||||
if (
|
||||
expected_revision is not None
|
||||
and text_session.revision != expected_revision
|
||||
):
|
||||
raise WorkflowRunTextSessionRevisionConflictError(
|
||||
expected_revision=expected_revision,
|
||||
actual_revision=text_session.revision,
|
||||
)
|
||||
|
||||
if session_data is not None:
|
||||
text_session.session_data = session_data
|
||||
if checkpoint is not None:
|
||||
text_session.checkpoint = checkpoint
|
||||
text_session.revision += 1
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(text_session)
|
||||
return text_session
|
||||
Loading…
Add table
Add a link
Reference in a new issue