mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: billing and credit management v2 (#429)
* feat: use mps generated correlation ID * chore: update pipecat submodule * feat: add credit purchase URL * feat: carve out billing page and show credit ledger * feat: deprecate dograh based quota tracking * fix: remove cost calculation from dograh codebase * fix: create mps account on migrate to v2 * chore: update pipecat
This commit is contained in:
parent
97d7103480
commit
1f1149f4d5
80 changed files with 3335 additions and 2057 deletions
|
|
@ -203,7 +203,7 @@ async def create_workflow_run_rows(
|
|||
Returns:
|
||||
Tuple of (workflow_run, user, workflow).
|
||||
"""
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
|
||||
async_session.add(org)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
|
|
@ -15,6 +19,7 @@ from api.services.configuration.ai_model_configuration import (
|
|||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_configuration_model_override_to_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
|
|
@ -22,6 +27,8 @@ from api.services.configuration.registry import (
|
|||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
GoogleLLMService,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
|
@ -49,6 +56,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
|||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
assert effective.managed_service_version == 2
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
|
||||
|
|
@ -92,6 +100,67 @@ def test_byok_v2_rejects_dograh_provider():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
|
||||
config = OrganizationAIModelConfigurationV2.model_validate(
|
||||
{
|
||||
"mode": "byok",
|
||||
"byok": {
|
||||
"mode": "realtime",
|
||||
"realtime": {
|
||||
"realtime": {
|
||||
"provider": "google_realtime",
|
||||
"api_key": "google-realtime-key",
|
||||
"model": "gemini-3.1-flash-live-preview",
|
||||
"voice": "Puck",
|
||||
"language": "en",
|
||||
},
|
||||
"llm": {
|
||||
"provider": "google",
|
||||
"api_key": "google-llm-key",
|
||||
"model": "gemini-2.0-flash",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
effective = compile_ai_model_configuration_v2(config)
|
||||
|
||||
assert effective.is_realtime is True
|
||||
assert effective.stt is None
|
||||
assert effective.tts is None
|
||||
assert await UserConfigurationValidator().validate(effective) == {
|
||||
"status": [{"model": "all", "message": "ok"}]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
|
||||
effective = EffectiveAIModelConfiguration(
|
||||
llm=GoogleLLMService(
|
||||
provider="google",
|
||||
api_key="google-llm-key",
|
||||
model="gemini-2.0-flash",
|
||||
),
|
||||
realtime=GoogleRealtimeLLMConfiguration(
|
||||
provider="google_realtime",
|
||||
api_key="google-realtime-key",
|
||||
model="gemini-3.1-flash-live-preview",
|
||||
voice="Puck",
|
||||
language="en",
|
||||
),
|
||||
is_realtime=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await UserConfigurationValidator().validate(effective)
|
||||
|
||||
assert exc_info.value.args[0] == [
|
||||
{"model": "stt", "message": "API key is missing"},
|
||||
{"model": "tts", "message": "API key is missing"},
|
||||
]
|
||||
|
||||
|
||||
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
|
||||
existing = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
|
|
@ -293,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
|||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
|
||||
monkeypatch,
|
||||
):
|
||||
from api.routes import organization as organization_routes
|
||||
|
||||
legacy = EffectiveAIModelConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
)
|
||||
expected_response = OrganizationAIModelConfigurationResponse(
|
||||
configuration={"version": 2, "mode": "dograh"},
|
||||
effective_configuration={},
|
||||
source="organization_v2",
|
||||
)
|
||||
|
||||
class FakeValidator:
|
||||
async def validate(self, *args, **kwargs):
|
||||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
upsert = AsyncMock()
|
||||
migrate_workflows = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"get_organization_ai_model_configuration_v2",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=legacy),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"UserConfigurationValidator",
|
||||
lambda: FakeValidator(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"upsert_organization_ai_model_configuration_v2",
|
||||
upsert,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"migrate_workflow_model_configurations_to_v2",
|
||||
migrate_workflows,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"_model_configuration_v2_response",
|
||||
AsyncMock(return_value=expected_response),
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_routes.migrate_model_configuration_v2(
|
||||
force=False,
|
||||
user=user,
|
||||
)
|
||||
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
|
||||
upsert.assert_awaited_once()
|
||||
migrate_workflows.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
assert response == expected_response
|
||||
|
|
|
|||
68
api/tests/test_auth_depends.py
Normal file
68
api/tests/test_auth_depends.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.auth import depends as auth_depends
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
|
||||
stack_user = {
|
||||
"id": "stack-user-1",
|
||||
"selected_team_id": "team-1",
|
||||
"primary_email_verified": False,
|
||||
}
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
email=None,
|
||||
provider_id="stack-user-1",
|
||||
selected_organization_id=None,
|
||||
)
|
||||
organization = SimpleNamespace(id=42)
|
||||
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
|
||||
monkeypatch.setattr(
|
||||
auth_depends.stackauth,
|
||||
"get_user",
|
||||
AsyncMock(return_value=stack_user),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_user_by_provider_id",
|
||||
AsyncMock(return_value=(user, False)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_organization_by_provider_id",
|
||||
AsyncMock(return_value=(organization, True)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"add_user_to_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"update_user_selected_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=existing_config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
|
||||
result = await auth_depends.get_user(authorization="Bearer token")
|
||||
|
||||
assert result is user
|
||||
assert result.selected_organization_id == 42
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
|
||||
"""Test function to verify cost calculation works"""
|
||||
sample_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 45380,
|
||||
"completion_tokens": 496,
|
||||
"total_tokens": 45876,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
|
||||
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
|
||||
"call_duration_seconds": 179,
|
||||
}
|
||||
|
||||
result = cost_calculator.calculate_total_cost(sample_usage)
|
||||
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
|
||||
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
|
||||
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
|
||||
assert (
|
||||
abs(
|
||||
result["total"]
|
||||
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
|
||||
)
|
||||
< 1e-10
|
||||
)
|
||||
110
api/tests/test_dograh_managed_correlation.py
Normal file
110
api/tests/test_dograh_managed_correlation.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
from pipecat.frames.frames import TTSStartedFrame
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from websockets.protocol import State
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
def __init__(self):
|
||||
self.state = State.OPEN
|
||||
self.messages: list[dict] = []
|
||||
|
||||
async def send(self, message: str) -> None:
|
||||
self.messages.append(json.loads(message))
|
||||
|
||||
async def close(self, *args, **kwargs) -> None:
|
||||
self.state = State.CLOSED
|
||||
|
||||
|
||||
def test_dograh_llm_uses_explicit_mps_correlation_id():
|
||||
service = DograhLLMService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
settings=OpenAILLMSettings(model="default"),
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
params = service.build_chat_completion_params(
|
||||
{
|
||||
"messages": [],
|
||||
"tools": OPENAI_NOT_GIVEN,
|
||||
"tool_choice": OPENAI_NOT_GIVEN,
|
||||
}
|
||||
)
|
||||
|
||||
assert params["metadata"]["correlation_id"] == "mps-corr-123"
|
||||
assert params["metadata"]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.stt.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhSTTService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=16000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.tts.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhTTSService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=24000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
async def _noop(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service.audio_context_available = lambda context_id: False
|
||||
service.create_audio_context = _noop
|
||||
service.start_ttfb_metrics = _noop
|
||||
service.start_tts_usage_metrics = _noop
|
||||
|
||||
frames = []
|
||||
async for frame in service.run_tts("hello", "ctx-1"):
|
||||
frames.append(frame)
|
||||
|
||||
assert isinstance(frames[0], TTSStartedFrame)
|
||||
assert fake_ws.messages[1]["type"] == "create_context"
|
||||
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[1]["mps_billing_version"] == "2"
|
||||
|
|
@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
|||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.xai.realtime import events
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
|
|
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_grok_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=GrokRealtimeLLMConfiguration(
|
||||
provider="grok_realtime",
|
||||
|
|
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.user import router
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
|
|
|
|||
|
|
@ -87,3 +87,317 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"correlation_id": "mps-corr-123"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.create_correlation_id(
|
||||
service_key="mps_sk_paid",
|
||||
workflow_run_id=42,
|
||||
) == {"correlation_id": "mps-corr-123"}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
{"workflow_run_id": 42},
|
||||
{
|
||||
"Authorization": "Bearer mps_sk_paid",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(200, {"organization_id": 42, "billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_billing_account_status(organization_id=42) == {
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/status",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.ensure_billing_account_v2(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
) == {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/balance",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, params, headers):
|
||||
calls.append(("GET", url, params, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_credit_ledger(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
) == {
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
|
||||
{"page": 3, "limit": 25},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"metered": True})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.report_platform_usage(
|
||||
organization_id=42,
|
||||
correlation_id="mps-corr-123",
|
||||
workflow_run_id=123,
|
||||
metadata={"source": "workflow_run_completion"},
|
||||
) == {"metered": True}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
|
||||
{
|
||||
"correlation_id": "mps-corr-123",
|
||||
"workflow_run_id": 123,
|
||||
"metadata": {"source": "workflow_run_completion"},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"metered": True})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.report_platform_usage(
|
||||
organization_id=42,
|
||||
duration_seconds=87.0,
|
||||
workflow_run_id=123,
|
||||
metadata={"source": "workflow_run_completion"},
|
||||
) == {"metered": True}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
|
||||
{
|
||||
"duration_seconds": 87.0,
|
||||
"workflow_run_id": 123,
|
||||
"metadata": {"source": "workflow_run_completion"},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
|
|
|||
99
api/tests/test_organization_usage_billing.py
Normal file
99
api/tests/test_organization_usage_billing.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.routes import organization_usage
|
||||
|
||||
|
||||
def test_is_mps_billing_v2_depends_only_on_account_mode():
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False
|
||||
assert organization_usage._is_mps_billing_v2(None) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(provider_id="provider-123")
|
||||
|
||||
assert await organization_usage._get_mps_billing_account_status(user, 42) == {
|
||||
"billing_mode": "v2"
|
||||
}
|
||||
get_status.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
|
||||
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_usage,
|
||||
"_get_mps_billing_account_status",
|
||||
AsyncMock(return_value={"billing_mode": "v2"}),
|
||||
)
|
||||
get_ledger = AsyncMock(
|
||||
return_value={
|
||||
"account": {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": 250,
|
||||
"currency": "USD",
|
||||
},
|
||||
"ledger_entries": [
|
||||
{
|
||||
"id": 99,
|
||||
"entry_type": "grant",
|
||||
"origin": "account_creation",
|
||||
"credits_delta": 250,
|
||||
"balance_after": 250,
|
||||
"created_at": "2026-06-12T00:00:00Z",
|
||||
}
|
||||
],
|
||||
"total_debits_credits": 75,
|
||||
"total_count": 101,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 5,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_credit_ledger",
|
||||
get_ledger,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_usage.get_billing_credits(
|
||||
page=3,
|
||||
limit=25,
|
||||
user=user,
|
||||
)
|
||||
|
||||
get_ledger.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
created_by="provider-123",
|
||||
)
|
||||
assert response.billing_version == "v2"
|
||||
assert response.total_credits_used == 75
|
||||
assert response.total_count == 101
|
||||
assert response.page == 3
|
||||
assert response.limit == 25
|
||||
assert response.total_pages == 5
|
||||
assert response.ledger_entries[0].id == 99
|
||||
|
|
@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
|
|||
|
||||
import pytest
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
contains_masked_key,
|
||||
mask_workflow_configurations,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.workflow.run_usage_response import format_public_usage_info
|
||||
|
||||
|
||||
def test_format_public_usage_info():
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
|
|||
from websockets.exceptions import ConnectionClosedError
|
||||
from websockets.frames import Close
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.ultravox_realtime import (
|
||||
_RESUMPTION_USER_MESSAGE,
|
||||
|
|
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_ultravox_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=UltravoxRealtimeLLMConfiguration(
|
||||
provider="ultravox_realtime",
|
||||
|
|
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
212
api/tests/test_workflow_run_billing.py
Normal file
212
api/tests/test_workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services import workflow_run_billing as workflow_run_billing_mod
|
||||
from api.services.workflow_run_billing import (
|
||||
report_completed_workflow_run_platform_usage,
|
||||
report_workflow_run_platform_usage,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
is_completed=True,
|
||||
initial_context={"mps_correlation_id": "mps-corr-123"},
|
||||
usage_info={"call_duration_seconds": 87},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
correlation_id="mps-corr-123",
|
||||
duration_seconds=None,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"duration_source": "mps_correlation",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.initial_context = {}
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
correlation_id=None,
|
||||
duration_seconds=87.0,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"duration_source": "dograh_usage_info",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v1"})
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
get_status.assert_awaited_once_with(organization_id=42)
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.initial_context = {}
|
||||
workflow_run.usage_info = {}
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
get_status.assert_not_awaited()
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.is_completed = False
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_run = AsyncMock(return_value=workflow_run)
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
get_run,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_completed_workflow_run_platform_usage(workflow_run.id)
|
||||
|
||||
get_run.assert_awaited_once_with(workflow_run.id)
|
||||
report_usage.assert_awaited_once()
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
calculate_workflow_run_cost,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
mode="textchat",
|
||||
created_at=datetime.now(UTC),
|
||||
usage_info={
|
||||
"llm": {},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
},
|
||||
cost_info={},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
|
||||
assert cost_info is not None
|
||||
assert cost_info["call_duration_seconds"] == 7
|
||||
assert "cost_breakdown" in cost_info
|
||||
assert "dograh_token_usage" in cost_info
|
||||
assert cost_info["charge_usd"] == 10.5
|
||||
update_usage.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None))
|
||||
update_run = AsyncMock()
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
AsyncMock(return_value=workflow_run),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_workflow_run", update_run
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
await calculate_workflow_run_cost(workflow_run.id)
|
||||
|
||||
update_run.assert_awaited_once()
|
||||
saved_kwargs = update_run.await_args.kwargs
|
||||
assert saved_kwargs["run_id"] == workflow_run.id
|
||||
assert "cost_breakdown" in saved_kwargs["cost_info"]
|
||||
update_usage.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.cost_info = {"call_id": "preserve-me"}
|
||||
|
||||
usage_delta_one = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 1_000,
|
||||
"completion_tokens": 100,
|
||||
"total_tokens": 1_100,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 3,
|
||||
}
|
||||
usage_delta_two = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 2_000,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 2_050,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 4,
|
||||
}
|
||||
merged_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 3_000,
|
||||
"completion_tokens": 150,
|
||||
"total_tokens": 3_150,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
}
|
||||
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
|
||||
second_delta = await apply_usage_delta_to_organization(
|
||||
workflow_run, usage_delta_two
|
||||
)
|
||||
total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
|
||||
total_workflow_run.usage_info = merged_usage
|
||||
total_cost = await build_workflow_run_cost_info(total_workflow_run)
|
||||
|
||||
assert first_delta is not None
|
||||
assert second_delta is not None
|
||||
assert total_cost is not None
|
||||
assert update_usage.await_count == 2
|
||||
assert update_usage.await_args_list[0].args == (
|
||||
42,
|
||||
first_delta["dograh_token_usage"],
|
||||
3.0,
|
||||
first_delta["charge_usd"],
|
||||
)
|
||||
assert update_usage.await_args_list[1].args == (
|
||||
42,
|
||||
second_delta["dograh_token_usage"],
|
||||
4.0,
|
||||
second_delta["charge_usd"],
|
||||
)
|
||||
assert (
|
||||
first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
|
||||
) == pytest.approx(total_cost["dograh_token_usage"])
|
||||
assert (
|
||||
first_delta["charge_usd"] + second_delta["charge_usd"]
|
||||
== total_cost["charge_usd"]
|
||||
)
|
||||
|
|
@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
|
||||
from api.db.models import OrganizationModel, UserModel
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
|
@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
|
|||
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
|
||||
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.cost_info[
|
||||
"call_duration_seconds"
|
||||
] == workflow_run.usage_info.get("call_duration_seconds", 0)
|
||||
assert "cost_breakdown" in workflow_run.cost_info
|
||||
assert "dograh_token_usage" in workflow_run.cost_info
|
||||
assert "call_duration_seconds" in workflow_run.usage_info
|
||||
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
|
||||
"Hello from the workflow tester."
|
||||
]
|
||||
|
|
@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn(
|
|||
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
|
||||
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.cost_info[
|
||||
"call_duration_seconds"
|
||||
] == workflow_run.usage_info.get("call_duration_seconds", 0)
|
||||
assert "cost_breakdown" in workflow_run.cost_info
|
||||
assert "dograh_token_usage" in workflow_run.cost_info
|
||||
assert "call_duration_seconds" in workflow_run.usage_info
|
||||
assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"]
|
||||
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
|
||||
"Welcome to the workflow tester.",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue