fix: fix review comments

This commit is contained in:
Abhishek Kumar 2026-05-21 15:17:14 +05:30
parent dfee942f9a
commit c7e0d06a2b
13 changed files with 477 additions and 253 deletions

View file

@ -151,9 +151,9 @@ class OrganizationUsageClient(BaseDBClient):
async def update_usage_after_run(
self,
organization_id: int,
actual_tokens: int,
duration_seconds: int = 0,
charge_usd: float = None,
actual_tokens: float,
duration_seconds: float = 0,
charge_usd: float | None = None,
) -> None:
"""Update usage after a workflow run completes with actual token count and duration.

View file

@ -32,16 +32,22 @@ class WorkflowRunClient(BaseDBClient):
campaign_id: int = None,
queued_run_id: int = None,
use_draft: bool = False,
organization_id: int | None = None,
) -> WorkflowRunModel:
async with self.async_session() as session:
# Get workflow and user to check organization
workflow = await session.execute(
workflow_query = (
select(WorkflowModel)
.options(joinedload(WorkflowModel.user))
.where(
WorkflowModel.id == workflow_id, WorkflowModel.user_id == user_id
)
)
if organization_id is not None:
workflow_query = workflow_query.where(
WorkflowModel.organization_id == organization_id
)
workflow = await session.execute(workflow_query)
workflow = workflow.scalars().first()
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")

View file

@ -153,6 +153,7 @@ async def initiate_call(
"telephony_configuration_id": telephony_configuration_id,
},
use_draft=True,
organization_id=user.selected_organization_id,
)
workflow_run_id = workflow_run.id
else:

View file

@ -1081,7 +1081,12 @@ async def create_workflow_run(
user: The user to create the workflow run for
"""
run = await db_client.create_workflow_run(
request.name, workflow_id, request.mode, user.id, use_draft=True
request.name,
workflow_id,
request.mode,
user.id,
use_draft=True,
organization_id=user.selected_organization_id,
)
return {
"id": run.id,

View file

@ -10,6 +10,7 @@ from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.quota_service import check_dograh_quota
from api.services.workflow.text_chat_session_service import (
TextChatPendingTurnLostError,
TextChatSessionExecutionError,
@ -95,16 +96,27 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
}
def _require_selected_organization_id(user: UserModel) -> int:
    """Return the user's selected organization id.

    Raises:
        HTTPException: 403 when the user has no selected organization.
    """
    organization_id = user.selected_organization_id
    if organization_id is not None:
        return organization_id
    raise HTTPException(status_code=403, detail="Organization context is required")
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
    """Reject the request when the quota check reports no remaining quota.

    Raises:
        HTTPException: 402 carrying the quota check's error message when
            ``has_quota`` is falsy.
    """
    quota = await check_dograh_quota(user, workflow_id=workflow_id)
    if quota.has_quota:
        return
    raise HTTPException(status_code=402, detail=quota.error_message)
async def _load_text_session_or_404(
workflow_id: int,
run_id: int,
user: UserModel,
) -> WorkflowRunTextSessionModel:
set_current_run_id(run_id)
if user.selected_organization_id is None:
raise HTTPException(status_code=403, detail="Organization context is required")
organization_id = _require_selected_organization_id(user)
text_session = await db_client.get_workflow_run_text_session(
run_id, organization_id=user.selected_organization_id
run_id, organization_id=organization_id
)
if not text_session or not text_session.workflow_run:
raise HTTPException(status_code=404, detail="Text chat session not found")
@ -148,6 +160,9 @@ async def create_text_chat_session(
request: CreateTextChatSessionRequest,
user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
organization_id = _require_selected_organization_id(user)
await _ensure_text_chat_quota(user, workflow_id)
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
try:
workflow_run = await db_client.create_workflow_run(
@ -157,6 +172,7 @@ async def create_text_chat_session(
user_id=user.id,
initial_context=request.initial_context,
use_draft=True,
organization_id=organization_id,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@ -221,6 +237,8 @@ async def append_text_chat_message(
user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
await _ensure_text_chat_quota(user, workflow_id)
try:
text_session = await append_text_chat_user_message(
run_id=run_id,

View file

@ -1,3 +1,5 @@
from decimal import Decimal
from loguru import logger
from api.db import db_client
@ -73,59 +75,80 @@ async def _get_pricing_organization(workflow_run):
return await db_client.get_organization_by_id(organization_id)
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
workflow_usage_info = workflow_run.usage_info
if not workflow_usage_info:
async def _build_usage_cost_snapshot(
usage_info: dict | None,
*,
workflow_run=None,
include_telephony_cost: bool = False,
organization=None,
calculated_at: str | None = None,
) -> dict | None:
if not usage_info:
logger.warning("No usage info available for workflow run")
return None
# Calculate cost breakdown
cost_breakdown = cost_calculator.calculate_total_cost(workflow_usage_info)
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
# Fetch telephony call cost
try:
telephony_cost = await _fetch_telephony_cost(workflow_run)
if telephony_cost:
telephony_cost_usd = telephony_cost["cost_usd"]
provider_name = telephony_cost["provider_name"]
cost_breakdown["telephony_call"] = telephony_cost_usd
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
cost_breakdown["total"] = (
float(cost_breakdown["total"]) + telephony_cost_usd
)
except Exception as e:
logger.error(f"Failed to fetch telephony call cost: {e}")
# Don't fail the whole cost calculation if telephony API fails
if include_telephony_cost and workflow_run is not None:
try:
telephony_cost = await _fetch_telephony_cost(workflow_run)
if telephony_cost:
telephony_cost_usd = telephony_cost["cost_usd"]
provider_name = telephony_cost["provider_name"]
cost_breakdown["telephony_call"] = telephony_cost_usd
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
cost_breakdown["total"] = (
float(cost_breakdown["total"]) + telephony_cost_usd
)
except Exception as e:
logger.error(f"Failed to fetch telephony call cost: {e}")
# Don't fail the whole cost calculation if telephony API fails
# Convert USD to Dograh Tokens (1 cent = 1 token)
dograh_tokens = round(float(cost_breakdown["total"]) * 100, 2)
total_cost_usd = Decimal(str(cost_breakdown["total"]))
dograh_tokens = float(total_cost_usd * Decimal("100"))
if organization is None and workflow_run is not None:
organization = await _get_pricing_organization(workflow_run)
# Get organization to check if it has USD pricing
org = await _get_pricing_organization(workflow_run)
charge_usd = None
# Calculate USD cost if organization has pricing configured
if org and org.price_per_second_usd:
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
charge_usd = duration_seconds * org.price_per_second_usd
if organization and organization.price_per_second_usd:
duration_seconds = usage_info.get("call_duration_seconds", 0)
charge_usd = float(
Decimal(str(duration_seconds))
* Decimal(str(organization.price_per_second_usd))
)
cost_info = {
**(workflow_run.cost_info or {}),
"cost_breakdown": cost_breakdown,
"total_cost_usd": float(cost_breakdown["total"]),
"total_cost_usd": float(total_cost_usd),
"dograh_token_usage": dograh_tokens,
"calculated_at": workflow_run.created_at.isoformat(),
"call_duration_seconds": workflow_usage_info.get("call_duration_seconds", 0),
"calculated_at": calculated_at
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
}
# Add USD cost if available
if charge_usd is not None:
cost_info["charge_usd"] = charge_usd
cost_info["price_per_second_usd"] = org.price_per_second_usd
cost_info["price_per_second_usd"] = organization.price_per_second_usd
return cost_info
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
    """Build the cost info for a workflow run from its recorded usage.

    The snapshot is computed with telephony cost enabled and stamped with the
    run's ``created_at``. Keys already stored in ``workflow_run.cost_info`` are
    kept, and freshly computed keys override them on collision.

    Returns:
        The merged cost info dict, or ``None`` when no snapshot could be built.
    """
    snapshot = await _build_usage_cost_snapshot(
        workflow_run.usage_info,
        workflow_run=workflow_run,
        include_telephony_cost=True,
        calculated_at=workflow_run.created_at.isoformat(),
    )
    if snapshot is None:
        return None

    # Start from the stored values so unrelated keys survive the recompute.
    merged = dict(workflow_run.cost_info or {})
    merged.update(snapshot)
    return merged
async def save_workflow_run_cost_info(
workflow_run_id: int, cost_info: dict | None
) -> None:
@ -152,6 +175,26 @@ async def apply_workflow_run_usage_to_organization(
)
async def apply_usage_delta_to_organization(
    workflow_run, usage_info: dict | None
) -> dict | None:
    """Charge an incremental slice of usage to the run's pricing organization.

    The cost snapshot is built from ``usage_info`` and the resolved
    organization only; ``workflow_run`` is used solely to look up that
    organization.

    Returns:
        The cost snapshot that was applied, or ``None`` when no pricing
        organization is found or no snapshot could be built.
    """
    organization = await _get_pricing_organization(workflow_run)
    if not organization:
        return None

    snapshot = await _build_usage_cost_snapshot(usage_info, organization=organization)
    if snapshot is None:
        return None

    # Missing or falsy amounts are reported as 0.0; charge_usd stays optional.
    token_usage = float(snapshot.get("dograh_token_usage") or 0)
    duration_seconds = float(snapshot.get("call_duration_seconds") or 0)
    await _update_organization_usage(
        organization,
        token_usage,
        duration_seconds,
        snapshot.get("charge_usd"),
    )
    return snapshot
async def calculate_workflow_run_cost(workflow_run_id: int):
logger.debug("Calculating cost for workflow run")

View file

@ -4,12 +4,17 @@ from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
from loguru import logger
from api.db import db_client
from api.db.models import WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.services.pricing.workflow_run_cost import build_workflow_run_cost_info
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
)
from api.services.workflow.text_chat_logs import (
build_text_chat_realtime_feedback_events,
)
@ -258,6 +263,15 @@ async def execute_pending_text_chat_turn(
)
workflow_run = await db_client.get_workflow_run_by_id(run_id)
if workflow_run:
try:
# Apply the per-turn delta so org usage tracks cumulative run cost
# without replaying the full session totals on every turn.
await apply_usage_delta_to_organization(workflow_run, execution.usage)
except Exception as e:
logger.error(
f"Failed to update organization usage for text chat run {run_id}: {e}"
)
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is not None:
await db_client.update_workflow_run(run_id, cost_info=cost_info)

View file

@ -419,8 +419,9 @@ class TestStartGreeting:
"""When a node has no greeting, the engine should queue initial LLM generation."""
dto = ReactFlowDTO(
nodes=[
StartCallRFNode(
RFNodeDTO(
id="start",
type="startCall",
position=Position(x=0, y=0),
data=StartCallNodeData(
name="Start",
@ -430,8 +431,9 @@ class TestStartGreeting:
extraction_enabled=False,
),
),
EndCallRFNode(
RFNodeDTO(
id="end",
type="endCall",
position=Position(x=0, y=200),
data=EndCallNodeData(
name="End",

View file

@ -6,6 +6,7 @@ import pytest
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
calculate_workflow_run_cost,
)
@ -85,3 +86,96 @@ async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrappe
assert saved_kwargs["run_id"] == workflow_run.id
assert "cost_breakdown" in saved_kwargs["cost_info"]
update_usage.assert_awaited_once()
@pytest.mark.asyncio
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
    monkeypatch,
):
    """Per-turn usage deltas are billed one by one and add up to the run total.

    Two deltas are applied separately; each must trigger exactly one
    organization usage update carrying only that delta's amounts, and their
    sum must match the cost computed from the merged usage.
    """

    def _llm_usage(prompt_tokens, completion_tokens, duration_seconds):
        # Single-model usage payload; only the counts and duration vary.
        return {
            "llm": {
                "OpenAILLMService#0|||gpt-4.1-mini": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                    "cache_read_input_tokens": 0,
                    "cache_creation_input_tokens": 0,
                }
            },
            "tts": {},
            "stt": {},
            "call_duration_seconds": duration_seconds,
        }

    workflow_run = _make_workflow_run()
    workflow_run.cost_info = {"call_id": "preserve-me"}

    usage_delta_one = _llm_usage(1_000, 100, 3)
    usage_delta_two = _llm_usage(2_000, 50, 4)
    # The merged usage is the element-wise sum of the two deltas.
    merged_usage = _llm_usage(3_000, 150, 7)

    get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
    update_usage = AsyncMock()
    monkeypatch.setattr(
        workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
    )
    monkeypatch.setattr(
        workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
    )

    first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
    second_delta = await apply_usage_delta_to_organization(
        workflow_run, usage_delta_two
    )

    total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
    total_workflow_run.usage_info = merged_usage
    total_cost = await build_workflow_run_cost_info(total_workflow_run)

    assert first_delta is not None
    assert second_delta is not None
    assert total_cost is not None

    # Keys already stored on the run must survive the recompute.
    assert total_cost["call_id"] == "preserve-me"

    assert update_usage.await_count == 2
    assert update_usage.await_args_list[0].args == (
        42,
        first_delta["dograh_token_usage"],
        3.0,
        first_delta["charge_usd"],
    )
    assert update_usage.await_args_list[1].args == (
        42,
        second_delta["dograh_token_usage"],
        4.0,
        second_delta["charge_usd"],
    )

    assert (
        first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
    ) == pytest.approx(total_cost["dograh_token_usage"])
    # Summed floats are compared approximately, like the token totals above,
    # so the check does not hinge on the fixture values being exactly
    # representable in binary.
    assert (
        first_delta["charge_usd"] + second_delta["charge_usd"]
    ) == pytest.approx(total_cost["charge_usd"])

View file

@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
@ -968,3 +969,226 @@ async def test_text_chat_session_is_not_accessible_from_another_org(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{created['workflow_run_id']}"
)
assert get_response.status_code == 404
@pytest.mark.asyncio
async def test_text_chat_session_creation_requires_selected_org_scope(
    db_session,
    async_session,
    test_client_factory,
):
    """Creating a session for a workflow in a non-selected org returns 404.

    The user has one organization selected while the workflow is created under
    a different one; the request must fail and leave no workflow run behind.
    """
    start_node = {
        "id": "start",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {
            "name": "Start",
            "prompt": "You are a helpful assistant.",
            "is_start": True,
            "allow_interrupt": False,
            "add_global_prompt": False,
        },
    }
    workflow_definition = {"nodes": [start_node], "edges": []}

    selected_org = OrganizationModel(provider_id="textchat-scope-a")
    workflow_org = OrganizationModel(provider_id="textchat-scope-b")
    async_session.add_all([selected_org, workflow_org])
    await async_session.flush()

    user = UserModel(
        provider_id="textchat-scope-user",
        selected_organization_id=selected_org.id,
    )
    async_session.add(user)
    await async_session.flush()

    user_configuration = UserConfiguration.model_validate(USER_CONFIGURATION)
    await db_session.update_user_configuration(
        user_id=user.id,
        configuration=user_configuration,
    )

    # Same owner, but the workflow lives in the organization that is NOT selected.
    workflow = await db_session.create_workflow(
        name="Cross-org workflow",
        workflow_definition=workflow_definition,
        user_id=user.id,
        organization_id=workflow_org.id,
    )

    llm = MockLLMService(
        mock_steps=[MockLLMService.create_text_chunks("Should never run.")],
        chunk_delay=0.001,
    )
    llm_patch = patch(
        "api.services.workflow.text_chat_runner.create_llm_service",
        return_value=llm,
    )
    recordings_patch = patch(
        "api.services.workflow.text_chat_runner.db_client.has_active_recordings",
        new=AsyncMock(return_value=False),
    )

    async with test_client_factory(user) as client:
        with llm_patch, recordings_patch:
            create_response = await client.post(
                f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
                json={},
            )

    assert create_response.status_code == 404

    runs_and_count = await db_session.get_workflow_runs_by_workflow_id(
        workflow.id,
        organization_id=workflow_org.id,
    )
    total_count = runs_and_count[1]
    assert total_count == 0
@pytest.mark.asyncio
async def test_text_chat_session_creation_rejects_quota_before_creating_run(
    db_session,
    async_session,
    test_client_factory,
):
    """A failed quota check yields 402 and no workflow run is created."""
    start_node = {
        "id": "start",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {
            "name": "Start",
            "prompt": "You are a helpful assistant.",
            "is_start": True,
            "allow_interrupt": False,
            "add_global_prompt": False,
        },
    }
    workflow_definition = {"nodes": [start_node], "edges": []}

    user, workflow = await _create_user_and_workflow(
        db_session,
        async_session,
        workflow_definition=workflow_definition,
        suffix="quota-create",
    )

    # The route's quota check is replaced with one that always denies.
    denied_quota = SimpleNamespace(
        has_quota=False,
        error_message="Quota exceeded",
    )
    quota_patch = patch(
        "api.routes.workflow_text_chat.check_dograh_quota",
        new=AsyncMock(return_value=denied_quota),
    )

    async with test_client_factory(user) as client:
        with quota_patch:
            create_response = await client.post(
                f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
                json={},
            )

    assert create_response.status_code == 402
    assert create_response.json()["detail"] == "Quota exceeded"

    runs_and_count = await db_session.get_workflow_runs_by_workflow_id(
        workflow.id,
        organization_id=workflow.organization_id,
    )
    total_count = runs_and_count[1]
    assert total_count == 0
@pytest.mark.asyncio
async def test_text_chat_append_rejects_quota_without_mutating_session(
    db_session,
    async_session,
    test_client_factory,
):
    """A quota failure on append yields 402 and leaves the session unchanged.

    The patched quota check allows the first call (session creation) and
    denies the second (message append); afterwards the session's revision,
    turns and status must equal what creation returned.
    """
    start_node = {
        "id": "start",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {
            "name": "Start",
            "prompt": "You are a helpful assistant.",
            "is_start": True,
            "allow_interrupt": False,
            "add_global_prompt": False,
        },
    }
    workflow_definition = {"nodes": [start_node], "edges": []}

    user, workflow = await _create_user_and_workflow(
        db_session,
        async_session,
        workflow_definition=workflow_definition,
        suffix="quota-append",
    )

    llm = MockLLMService(
        mock_steps=[
            MockLLMService.create_text_chunks("Hello from the workflow tester.")
        ],
        chunk_delay=0.001,
    )

    # First quota check passes, second one is denied.
    quota_results = [
        SimpleNamespace(has_quota=True, error_message=""),
        SimpleNamespace(
            has_quota=False,
            error_message="Quota exceeded on append",
        ),
    ]
    quota_patch = patch(
        "api.routes.workflow_text_chat.check_dograh_quota",
        new=AsyncMock(side_effect=quota_results),
    )
    llm_patch = patch(
        "api.services.workflow.text_chat_runner.create_llm_service",
        return_value=llm,
    )
    recordings_patch = patch(
        "api.services.workflow.text_chat_runner.db_client.has_active_recordings",
        new=AsyncMock(return_value=False),
    )

    sessions_url = f"/api/v1/workflow/{workflow.id}/text-chat/sessions"

    async with test_client_factory(user) as client:
        with quota_patch, llm_patch, recordings_patch:
            create_response = await client.post(sessions_url, json={})
            assert create_response.status_code == 200
            created = create_response.json()

            session_url = f"{sessions_url}/{created['workflow_run_id']}"

            append_response = await client.post(
                f"{session_url}/messages",
                json={
                    "text": "This should be rejected",
                    "expected_revision": created["revision"],
                },
            )
            assert append_response.status_code == 402

            session_response = await client.get(session_url)
            assert session_response.status_code == 200
            session_payload = session_response.json()

    assert append_response.json()["detail"] == "Quota exceeded on append"
    assert session_payload["revision"] == created["revision"]

    stored_session_data = session_payload["session_data"]
    created_session_data = created["session_data"]
    assert stored_session_data["turns"] == created_session_data["turns"]
    assert stored_session_data["status"] == created_session_data["status"]