feat: billing and credit management v2 (#429)

* feat: use mps generated correlation ID

* chore: update pipecat submodule

* feat: add credit purchase URL

* feat: carve out billing page and show credit ledger

* feat: deprecate dograh based quota tracking

* fix: remove cost calculation from dograh codebase

* fix: create mps account on migrate to v2

* chore: update pipecat
This commit is contained in:
Abhishek 2026-06-12 14:55:30 +05:30 committed by GitHub
parent 97d7103480
commit 1f1149f4d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 3335 additions and 2057 deletions

View file

@ -203,7 +203,7 @@ async def create_workflow_run_rows(
Returns:
Tuple of (workflow_run, user, workflow).
"""
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
async_session.add(org)

View file

@ -1,12 +1,16 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
@ -15,6 +19,7 @@ from api.services.configuration.ai_model_configuration import (
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
@ -22,6 +27,8 @@ from api.services.configuration.registry import (
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
GoogleLLMService,
GoogleRealtimeLLMConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
@ -49,6 +56,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
assert effective.embeddings.model == "default"
assert effective.managed_service_version == 2
def test_dograh_v2_rejects_non_predefined_speed():
@ -92,6 +100,67 @@ def test_byok_v2_rejects_dograh_provider():
)
@pytest.mark.asyncio
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
config = OrganizationAIModelConfigurationV2.model_validate(
{
"mode": "byok",
"byok": {
"mode": "realtime",
"realtime": {
"realtime": {
"provider": "google_realtime",
"api_key": "google-realtime-key",
"model": "gemini-3.1-flash-live-preview",
"voice": "Puck",
"language": "en",
},
"llm": {
"provider": "google",
"api_key": "google-llm-key",
"model": "gemini-2.0-flash",
},
},
},
}
)
effective = compile_ai_model_configuration_v2(config)
assert effective.is_realtime is True
assert effective.stt is None
assert effective.tts is None
assert await UserConfigurationValidator().validate(effective) == {
"status": [{"model": "all", "message": "ok"}]
}
@pytest.mark.asyncio
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
effective = EffectiveAIModelConfiguration(
llm=GoogleLLMService(
provider="google",
api_key="google-llm-key",
model="gemini-2.0-flash",
),
realtime=GoogleRealtimeLLMConfiguration(
provider="google_realtime",
api_key="google-realtime-key",
model="gemini-3.1-flash-live-preview",
voice="Puck",
language="en",
),
is_realtime=False,
)
with pytest.raises(ValueError) as exc_info:
await UserConfigurationValidator().validate(effective)
assert exc_info.value.args[0] == [
{"model": "stt", "message": "API key is missing"},
{"model": "tts", "message": "API key is missing"},
]
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
existing = OrganizationAIModelConfigurationV2(
mode="dograh",
@ -293,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
@pytest.mark.asyncio
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
monkeypatch,
):
from api.routes import organization as organization_routes
legacy = EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
tts=DograhTTSService(
provider="dograh",
api_key=["mps-secret"],
model="default",
voice="default",
),
stt=DograhSTTService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
)
expected_response = OrganizationAIModelConfigurationResponse(
configuration={"version": 2, "mode": "dograh"},
effective_configuration={},
source="organization_v2",
)
class FakeValidator:
async def validate(self, *args, **kwargs):
return {"status": [{"model": "all", "message": "ok"}]}
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
upsert = AsyncMock()
migrate_workflows = AsyncMock()
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_routes,
"get_organization_ai_model_configuration_v2",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
organization_routes.db_client,
"get_user_configurations",
AsyncMock(return_value=legacy),
)
monkeypatch.setattr(
organization_routes,
"UserConfigurationValidator",
lambda: FakeValidator(),
)
monkeypatch.setattr(
organization_routes,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
monkeypatch.setattr(
organization_routes,
"upsert_organization_ai_model_configuration_v2",
upsert,
)
monkeypatch.setattr(
organization_routes,
"migrate_workflow_model_configurations_to_v2",
migrate_workflows,
)
monkeypatch.setattr(
organization_routes,
"_model_configuration_v2_response",
AsyncMock(return_value=expected_response),
)
user = SimpleNamespace(
id=7,
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_routes.migrate_model_configuration_v2(
force=False,
user=user,
)
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
upsert.assert_awaited_once()
migrate_workflows.assert_awaited_once_with(
organization_id=42,
fallback_user_config=legacy,
)
assert response == expected_response

View file

@ -0,0 +1,68 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.auth import depends as auth_depends
@pytest.mark.asyncio
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
stack_user = {
"id": "stack-user-1",
"selected_team_id": "team-1",
"primary_email_verified": False,
}
user = SimpleNamespace(
id=7,
email=None,
provider_id="stack-user-1",
selected_organization_id=None,
)
organization = SimpleNamespace(id=42)
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
monkeypatch.setattr(
auth_depends.stackauth,
"get_user",
AsyncMock(return_value=stack_user),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_user_by_provider_id",
AsyncMock(return_value=(user, False)),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_organization_by_provider_id",
AsyncMock(return_value=(organization, True)),
)
monkeypatch.setattr(
auth_depends.db_client,
"add_user_to_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"update_user_selected_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_user_configurations",
AsyncMock(return_value=existing_config),
)
monkeypatch.setattr(
auth_depends,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
result = await auth_depends.get_user(authorization="Bearer token")
assert result is user
assert result.selected_organization_id == 42
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")

View file

@ -1,31 +0,0 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
"""Test function to verify cost calculation works"""
sample_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 45380,
"completion_tokens": 496,
"total_tokens": 45876,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
"call_duration_seconds": 179,
}
result = cost_calculator.calculate_total_cost(sample_usage)
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
assert (
abs(
result["total"]
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
)
< 1e-10
)

View file

@ -0,0 +1,110 @@
import json
import pytest
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.frames.frames import TTSStartedFrame
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.openai.base_llm import OpenAILLMSettings
from websockets.protocol import State
class _FakeWebSocket:
def __init__(self):
self.state = State.OPEN
self.messages: list[dict] = []
async def send(self, message: str) -> None:
self.messages.append(json.loads(message))
async def close(self, *args, **kwargs) -> None:
self.state = State.CLOSED
def test_dograh_llm_uses_explicit_mps_correlation_id():
service = DograhLLMService(
api_key="mps-secret",
correlation_id="mps-corr-123",
settings=OpenAILLMSettings(model="default"),
)
service._start_metadata = {"workflow_run_id": 99}
params = service.build_chat_completion_params(
{
"messages": [],
"tools": OPENAI_NOT_GIVEN,
"tool_choice": OPENAI_NOT_GIVEN,
}
)
assert params["metadata"]["correlation_id"] == "mps-corr-123"
assert params["metadata"]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.stt.websocket_connect",
fake_connect,
)
service = DograhSTTService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=16000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.tts.websocket_connect",
fake_connect,
)
service = DograhTTSService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=24000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
async def _noop(*args, **kwargs):
return None
service.audio_context_available = lambda context_id: False
service.create_audio_context = _noop
service.start_ttfb_metrics = _noop
service.start_tts_usage_metrics = _noop
frames = []
async for frame in service.run_tts("hello", "ctx-1"):
frames.append(frame)
assert isinstance(frames[0], TTSStartedFrame)
assert fake_ws.messages[1]["type"] == "create_context"
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[1]["mps_billing_version"] == "2"

View file

@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.xai.realtime import events
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
def test_factory_creates_dograh_grok_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=GrokRealtimeLLMConfiguration(
provider="grok_realtime",
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -5,7 +5,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.routes.user import router
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.depends import get_user
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (

View file

@ -87,3 +87,317 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
"Content-Type": "application/json",
},
)
@pytest.mark.asyncio
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"correlation_id": "mps-corr-123"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
client = MPSServiceKeyClient()
assert await client.create_correlation_id(
service_key="mps_sk_paid",
workflow_run_id=42,
) == {"correlation_id": "mps-corr-123"}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
{"workflow_run_id": 42},
{
"Authorization": "Bearer mps_sk_paid",
"Content-Type": "application/json",
},
)
]
@pytest.mark.asyncio
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(200, {"organization_id": 42, "billing_mode": "v2"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_billing_account_status(organization_id=42) == {
"organization_id": 42,
"billing_mode": "v2",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/status",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(
200,
{
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.ensure_billing_account_v2(
organization_id=42,
created_by="provider-123",
) == {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/balance",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, params, headers):
calls.append(("GET", url, params, headers))
return _Response(
200,
{
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_credit_ledger(
organization_id=42,
page=3,
limit=25,
) == {
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
{"page": 3, "limit": 25},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
correlation_id="mps-corr-123",
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"correlation_id": "mps-corr-123",
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
duration_seconds=87.0,
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"duration_seconds": 87.0,
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]

View file

@ -0,0 +1,99 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.routes import organization_usage
def test_is_mps_billing_v2_depends_only_on_account_mode():
assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True
assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False
assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False
assert organization_usage._is_mps_billing_v2(None) is False
@pytest.mark.asyncio
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
get_status = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_billing_account_status",
get_status,
)
user = SimpleNamespace(provider_id="provider-123")
assert await organization_usage._get_mps_billing_account_status(user, 42) == {
"billing_mode": "v2"
}
get_status.assert_awaited_once_with(
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_usage,
"_get_mps_billing_account_status",
AsyncMock(return_value={"billing_mode": "v2"}),
)
get_ledger = AsyncMock(
return_value={
"account": {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": 250,
"currency": "USD",
},
"ledger_entries": [
{
"id": 99,
"entry_type": "grant",
"origin": "account_creation",
"credits_delta": 250,
"balance_after": 250,
"created_at": "2026-06-12T00:00:00Z",
}
],
"total_debits_credits": 75,
"total_count": 101,
"page": 3,
"limit": 25,
"total_pages": 5,
}
)
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_credit_ledger",
get_ledger,
)
user = SimpleNamespace(
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_usage.get_billing_credits(
page=3,
limit=25,
user=user,
)
get_ledger.assert_awaited_once_with(
organization_id=42,
page=3,
limit=25,
created_by="provider-123",
)
assert response.billing_version == "v2"
assert response.total_credits_used == 75
assert response.total_count == 101
assert response.page == 3
assert response.limit == 25
assert response.total_pages == 5
assert response.ledger_entries[0].id == 99

View file

@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
import pytest
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
contains_masked_key,
mask_workflow_configurations,

View file

@ -1,4 +1,4 @@
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.workflow.run_usage_response import format_public_usage_info
def test_format_public_usage_info():

View file

@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
from websockets.exceptions import ConnectionClosedError
from websockets.frames import Close
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
from api.services.pipecat.realtime.ultravox_realtime import (
_RESUMPTION_USER_MESSAGE,
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
def test_factory_creates_dograh_ultravox_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=UltravoxRealtimeLLMConfiguration(
provider="ultravox_realtime",
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -0,0 +1,212 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
is_completed=True,
initial_context={"mps_correlation_id": "mps-corr-123"},
usage_info={"call_duration_seconds": 87},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id="mps-corr-123",
duration_seconds=None,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "mps_correlation",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id=None,
duration_seconds=87.0,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "dograh_usage_info",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v1"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_awaited_once_with(organization_id=42)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
workflow_run.usage_info = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_not_awaited()
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
workflow_run = _make_workflow_run()
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
workflow_run = _make_workflow_run()
workflow_run.is_completed = False
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
workflow_run = _make_workflow_run()
get_run = AsyncMock(return_value=workflow_run)
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.db_client,
"get_workflow_run_by_id",
get_run,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_completed_workflow_run_platform_usage(workflow_run.id)
get_run.assert_awaited_once_with(workflow_run.id)
report_usage.assert_awaited_once()

View file

@ -1,181 +0,0 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
calculate_workflow_run_cost,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
mode="textchat",
created_at=datetime.now(UTC),
usage_info={
"llm": {},
"tts": {},
"stt": {},
"call_duration_seconds": 7,
},
cost_info={},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
workflow_run = _make_workflow_run()
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
cost_info = await build_workflow_run_cost_info(workflow_run)
assert cost_info is not None
assert cost_info["call_duration_seconds"] == 7
assert "cost_breakdown" in cost_info
assert "dograh_token_usage" in cost_info
assert cost_info["charge_usd"] == 10.5
update_usage.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None))
update_run = AsyncMock()
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_workflow_run", update_run
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
await calculate_workflow_run_cost(workflow_run.id)
update_run.assert_awaited_once()
saved_kwargs = update_run.await_args.kwargs
assert saved_kwargs["run_id"] == workflow_run.id
assert "cost_breakdown" in saved_kwargs["cost_info"]
update_usage.assert_awaited_once()
@pytest.mark.asyncio
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.cost_info = {"call_id": "preserve-me"}
usage_delta_one = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 1_000,
"completion_tokens": 100,
"total_tokens": 1_100,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 3,
}
usage_delta_two = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 2_000,
"completion_tokens": 50,
"total_tokens": 2_050,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 4,
}
merged_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 3_000,
"completion_tokens": 150,
"total_tokens": 3_150,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 7,
}
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
second_delta = await apply_usage_delta_to_organization(
workflow_run, usage_delta_two
)
total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
total_workflow_run.usage_info = merged_usage
total_cost = await build_workflow_run_cost_info(total_workflow_run)
assert first_delta is not None
assert second_delta is not None
assert total_cost is not None
assert update_usage.await_count == 2
assert update_usage.await_args_list[0].args == (
42,
first_delta["dograh_token_usage"],
3.0,
first_delta["charge_usd"],
)
assert update_usage.await_args_list[1].args == (
42,
second_delta["dograh_token_usage"],
4.0,
second_delta["charge_usd"],
)
assert (
first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
) == pytest.approx(total_cost["dograh_token_usage"])
assert (
first_delta["charge_usd"] + second_delta["charge_usd"]
== total_cost["charge_usd"]
)

View file

@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from api.db.models import OrganizationModel, UserModel
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
from pipecat.tests import MockLLMService
@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Hello from the workflow tester."
]
@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn(
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"]
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Welcome to the workflow tester.",