fix: remove cost calculation from dograh codebase

This commit is contained in:
Abhishek Kumar 2026-06-12 13:26:33 +05:30
parent 7d4e2e06a9
commit 8f241b89d2
39 changed files with 1067 additions and 1460 deletions

View file

@ -15,6 +15,7 @@ from api.services.configuration.ai_model_configuration import (
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
@ -22,6 +23,8 @@ from api.services.configuration.registry import (
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
GoogleLLMService,
GoogleRealtimeLLMConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
@ -93,6 +96,67 @@ def test_byok_v2_rejects_dograh_provider():
)
@pytest.mark.asyncio
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
config = OrganizationAIModelConfigurationV2.model_validate(
{
"mode": "byok",
"byok": {
"mode": "realtime",
"realtime": {
"realtime": {
"provider": "google_realtime",
"api_key": "google-realtime-key",
"model": "gemini-3.1-flash-live-preview",
"voice": "Puck",
"language": "en",
},
"llm": {
"provider": "google",
"api_key": "google-llm-key",
"model": "gemini-2.0-flash",
},
},
},
}
)
effective = compile_ai_model_configuration_v2(config)
assert effective.is_realtime is True
assert effective.stt is None
assert effective.tts is None
assert await UserConfigurationValidator().validate(effective) == {
"status": [{"model": "all", "message": "ok"}]
}
@pytest.mark.asyncio
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
effective = EffectiveAIModelConfiguration(
llm=GoogleLLMService(
provider="google",
api_key="google-llm-key",
model="gemini-2.0-flash",
),
realtime=GoogleRealtimeLLMConfiguration(
provider="google_realtime",
api_key="google-realtime-key",
model="gemini-3.1-flash-live-preview",
voice="Puck",
language="en",
),
is_realtime=False,
)
with pytest.raises(ValueError) as exc_info:
await UserConfigurationValidator().validate(effective)
assert exc_info.value.args[0] == [
{"model": "stt", "message": "API key is missing"},
{"model": "tts", "message": "API key is missing"},
]
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
existing = OrganizationAIModelConfigurationV2(
mode="dograh",

View file

@ -1,31 +0,0 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
"""Test function to verify cost calculation works"""
sample_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 45380,
"completion_tokens": 496,
"total_tokens": 45876,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
"call_duration_seconds": 179,
}
result = cost_calculator.calculate_total_cost(sample_usage)
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
assert (
abs(
result["total"]
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
)
< 1e-10
)

View file

@ -128,3 +128,152 @@ async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
},
)
]
@pytest.mark.asyncio
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(200, {"organization_id": 42, "billing_mode": "v2"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_billing_account_status(organization_id=42) == {
"organization_id": 42,
"billing_mode": "v2",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/status",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
correlation_id="mps-corr-123",
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"correlation_id": "mps-corr-123",
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
duration_seconds=87.0,
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"duration_seconds": 87.0,
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]

View file

@ -0,0 +1,33 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.routes import organization_usage
def test_is_mps_billing_v2_depends_only_on_account_mode():
assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True
assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False
assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False
assert organization_usage._is_mps_billing_v2(None) is False
@pytest.mark.asyncio
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
get_status = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_billing_account_status",
get_status,
)
user = SimpleNamespace(provider_id="provider-123")
assert await organization_usage._get_mps_billing_account_status(user, 42) == {
"billing_mode": "v2"
}
get_status.assert_awaited_once_with(
organization_id=42,
created_by="provider-123",
)

View file

@ -1,4 +1,4 @@
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.workflow.run_usage_response import format_public_usage_info
def test_format_public_usage_info():

View file

@ -0,0 +1,212 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
is_completed=True,
initial_context={"mps_correlation_id": "mps-corr-123"},
usage_info={"call_duration_seconds": 87},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id="mps-corr-123",
duration_seconds=None,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "mps_correlation",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id=None,
duration_seconds=87.0,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "dograh_usage_info",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v1"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_awaited_once_with(organization_id=42)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
workflow_run.usage_info = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_not_awaited()
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
workflow_run = _make_workflow_run()
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
workflow_run = _make_workflow_run()
workflow_run.is_completed = False
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
workflow_run = _make_workflow_run()
get_run = AsyncMock(return_value=workflow_run)
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.db_client,
"get_workflow_run_by_id",
get_run,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_completed_workflow_run_platform_usage(workflow_run.id)
get_run.assert_awaited_once_with(workflow_run.id)
report_usage.assert_awaited_once()

View file

@ -1,181 +0,0 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
calculate_workflow_run_cost,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
mode="textchat",
created_at=datetime.now(UTC),
usage_info={
"llm": {},
"tts": {},
"stt": {},
"call_duration_seconds": 7,
},
cost_info={},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
workflow_run = _make_workflow_run()
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
cost_info = await build_workflow_run_cost_info(workflow_run)
assert cost_info is not None
assert cost_info["call_duration_seconds"] == 7
assert "cost_breakdown" in cost_info
assert "dograh_token_usage" in cost_info
assert cost_info["charge_usd"] == 10.5
update_usage.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None))
update_run = AsyncMock()
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_workflow_run", update_run
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
await calculate_workflow_run_cost(workflow_run.id)
update_run.assert_awaited_once()
saved_kwargs = update_run.await_args.kwargs
assert saved_kwargs["run_id"] == workflow_run.id
assert "cost_breakdown" in saved_kwargs["cost_info"]
update_usage.assert_awaited_once()
@pytest.mark.asyncio
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.cost_info = {"call_id": "preserve-me"}
usage_delta_one = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 1_000,
"completion_tokens": 100,
"total_tokens": 1_100,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 3,
}
usage_delta_two = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 2_000,
"completion_tokens": 50,
"total_tokens": 2_050,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 4,
}
merged_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 3_000,
"completion_tokens": 150,
"total_tokens": 3_150,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 7,
}
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
second_delta = await apply_usage_delta_to_organization(
workflow_run, usage_delta_two
)
total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
total_workflow_run.usage_info = merged_usage
total_cost = await build_workflow_run_cost_info(total_workflow_run)
assert first_delta is not None
assert second_delta is not None
assert total_cost is not None
assert update_usage.await_count == 2
assert update_usage.await_args_list[0].args == (
42,
first_delta["dograh_token_usage"],
3.0,
first_delta["charge_usd"],
)
assert update_usage.await_args_list[1].args == (
42,
second_delta["dograh_token_usage"],
4.0,
second_delta["charge_usd"],
)
assert (
first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
) == pytest.approx(total_cost["dograh_token_usage"])
assert (
first_delta["charge_usd"] + second_delta["charge_usd"]
== total_cost["charge_usd"]
)

View file

@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Hello from the workflow tester."
]
@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn(
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"]
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Welcome to the workflow tester.",