mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
fix: create mps account on migrate to v2
This commit is contained in:
parent
8f241b89d2
commit
724e1d456b
14 changed files with 666 additions and 61 deletions
|
|
@ -5,7 +5,11 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
|
||||
from api.constants import (
|
||||
DEFAULT_CAMPAIGN_RETRY_CONFIG,
|
||||
DEFAULT_ORG_CONCURRENCY_LIMIT,
|
||||
DEPLOYMENT_MODE,
|
||||
)
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
|
|
@ -55,6 +59,7 @@ from api.services.configuration.registry import (
|
|||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.organization_context import (
|
||||
OrganizationContextResponse,
|
||||
get_organization_context,
|
||||
|
|
@ -359,6 +364,23 @@ async def migrate_model_configuration_v2(
|
|||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to initialize MPS billing v2 account for organization {}: {}",
|
||||
organization_id,
|
||||
exc,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to initialize MPS billing v2 account",
|
||||
)
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
|
|
|
|||
|
|
@ -74,6 +74,10 @@ class MPSBillingCreditsResponse(BaseModel):
|
|||
total_quota: float = 0.0
|
||||
account: Optional[MPSBillingAccountResponse] = None
|
||||
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
|
||||
total_count: int = 0
|
||||
page: int = 1
|
||||
limit: int = 50
|
||||
total_pages: int = 0
|
||||
|
||||
|
||||
def _optional_int(value: Any) -> Optional[int]:
|
||||
|
|
@ -224,10 +228,11 @@ async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResp
|
|||
|
||||
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
|
||||
async def get_billing_credits(
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Return legacy MPS credits or v2 billing ledger details for the org."""
|
||||
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
|
||||
try:
|
||||
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
|
@ -239,11 +244,18 @@ async def get_billing_credits(
|
|||
|
||||
ledger = await mps_service_key_client.get_credit_ledger(
|
||||
organization_id=organization_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
account = ledger.get("account") or {}
|
||||
ledger_entries = ledger.get("ledger_entries") or []
|
||||
total_count = int(ledger.get("total_count") or len(ledger_entries))
|
||||
response_limit = int(ledger.get("limit") or limit)
|
||||
total_pages = int(
|
||||
ledger.get("total_pages")
|
||||
or ((total_count + response_limit - 1) // response_limit)
|
||||
)
|
||||
workflow_ids_by_run_id: dict[int, int] = {}
|
||||
workflow_run_ids = {
|
||||
workflow_run_id
|
||||
|
|
@ -266,6 +278,8 @@ async def get_billing_credits(
|
|||
for entry in ledger_entries
|
||||
if float(entry.get("credits_delta") or 0.0) < 0
|
||||
)
|
||||
if ledger.get("total_debits_credits") is not None:
|
||||
total_debits = float(ledger["total_debits_credits"])
|
||||
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="v2",
|
||||
|
|
@ -308,6 +322,10 @@ async def get_billing_credits(
|
|||
)
|
||||
for entry in ledger_entries
|
||||
],
|
||||
total_count=total_count,
|
||||
page=int(ledger.get("page") or page),
|
||||
limit=response_limit,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from api.enums import PostHogEvent
|
|||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import decode_jwt_token
|
||||
|
||||
|
|
@ -110,6 +111,19 @@ async def get_user(
|
|||
# This prevents race conditions where multiple concurrent requests
|
||||
# might try to create configurations
|
||||
if org_was_created:
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization.id,
|
||||
created_by=str(stack_user["id"]),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize hosted MPS billing account for "
|
||||
"organization {}",
|
||||
organization.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
existing_cfg = await db_client.get_user_configurations(user_model.id)
|
||||
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
|
|
@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": "Auto-generated key for OSS user",
|
||||
"expires_in_days": 7, # Short-lived for OSS
|
||||
"created_by": user_provider_id,
|
||||
|
|
@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": f"Auto-generated key for organization {organization_id}",
|
||||
"organization_id": organization_id,
|
||||
"expires_in_days": 90, # Longer-lived for authenticated users
|
||||
|
|
|
|||
23
api/services/mps_billing.py
Normal file
23
api/services/mps_billing.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
async def ensure_hosted_mps_billing_account_v2(
|
||||
organization_id: int,
|
||||
*,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Ensure hosted orgs have an MPS billing v2 account.
|
||||
|
||||
OSS deployments use legacy per-key quota accounting and do not create MPS
|
||||
billing accounts.
|
||||
"""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return None
|
||||
|
||||
return await mps_service_key_client.ensure_billing_account_v2(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
@ -394,6 +394,7 @@ class MPSServiceKeyClient:
|
|||
async def get_credit_ledger(
|
||||
self,
|
||||
organization_id: int,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
@ -401,7 +402,7 @@ class MPSServiceKeyClient:
|
|||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
|
||||
params={"limit": limit},
|
||||
params={"page": page, "limit": limit},
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
|
|
@ -449,6 +450,34 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def ensure_billing_account_v2(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create or return the MPS v2 billing account for an organization."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to ensure MPS billing account v2: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to ensure MPS billing account v2: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
|
|
@ -358,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
|||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
|
||||
monkeypatch,
|
||||
):
|
||||
from api.routes import organization as organization_routes
|
||||
|
||||
legacy = EffectiveAIModelConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
)
|
||||
expected_response = OrganizationAIModelConfigurationResponse(
|
||||
configuration={"version": 2, "mode": "dograh"},
|
||||
effective_configuration={},
|
||||
source="organization_v2",
|
||||
)
|
||||
|
||||
class FakeValidator:
|
||||
async def validate(self, *args, **kwargs):
|
||||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
upsert = AsyncMock()
|
||||
migrate_workflows = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"get_organization_ai_model_configuration_v2",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=legacy),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"UserConfigurationValidator",
|
||||
lambda: FakeValidator(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"upsert_organization_ai_model_configuration_v2",
|
||||
upsert,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"migrate_workflow_model_configurations_to_v2",
|
||||
migrate_workflows,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"_model_configuration_v2_response",
|
||||
AsyncMock(return_value=expected_response),
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_routes.migrate_model_configuration_v2(
|
||||
force=False,
|
||||
user=user,
|
||||
)
|
||||
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
|
||||
upsert.assert_awaited_once()
|
||||
migrate_workflows.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
assert response == expected_response
|
||||
|
|
|
|||
68
api/tests/test_auth_depends.py
Normal file
68
api/tests/test_auth_depends.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.auth import depends as auth_depends
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
|
||||
stack_user = {
|
||||
"id": "stack-user-1",
|
||||
"selected_team_id": "team-1",
|
||||
"primary_email_verified": False,
|
||||
}
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
email=None,
|
||||
provider_id="stack-user-1",
|
||||
selected_organization_id=None,
|
||||
)
|
||||
organization = SimpleNamespace(id=42)
|
||||
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
|
||||
monkeypatch.setattr(
|
||||
auth_depends.stackauth,
|
||||
"get_user",
|
||||
AsyncMock(return_value=stack_user),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_user_by_provider_id",
|
||||
AsyncMock(return_value=(user, False)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_organization_by_provider_id",
|
||||
AsyncMock(return_value=(organization, True)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"add_user_to_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"update_user_selected_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=existing_config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
|
||||
result = await auth_depends.get_user(authorization="Bearer token")
|
||||
|
||||
assert result is user
|
||||
assert result.selected_organization_id == 42
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")
|
||||
|
|
@ -175,6 +175,130 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.ensure_billing_account_v2(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
) == {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/balance",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, params, headers):
|
||||
calls.append(("GET", url, params, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_credit_ledger(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
) == {
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
|
||||
{"page": 3, "limit": 25},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
|
||||
calls = []
|
||||
|
|
|
|||
|
|
@ -31,3 +31,69 @@ async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch)
|
|||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
|
||||
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_usage,
|
||||
"_get_mps_billing_account_status",
|
||||
AsyncMock(return_value={"billing_mode": "v2"}),
|
||||
)
|
||||
get_ledger = AsyncMock(
|
||||
return_value={
|
||||
"account": {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": 250,
|
||||
"currency": "USD",
|
||||
},
|
||||
"ledger_entries": [
|
||||
{
|
||||
"id": 99,
|
||||
"entry_type": "grant",
|
||||
"origin": "account_creation",
|
||||
"credits_delta": 250,
|
||||
"balance_after": 250,
|
||||
"created_at": "2026-06-12T00:00:00Z",
|
||||
}
|
||||
],
|
||||
"total_debits_credits": 75,
|
||||
"total_count": 101,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 5,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_credit_ledger",
|
||||
get_ledger,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_usage.get_billing_credits(
|
||||
page=3,
|
||||
limit=25,
|
||||
user=user,
|
||||
)
|
||||
|
||||
get_ledger.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
created_by="provider-123",
|
||||
)
|
||||
assert response.billing_version == "v2"
|
||||
assert response.total_credits_used == 75
|
||||
assert response.total_count == 101
|
||||
assert response.page == 3
|
||||
assert response.limit == 25
|
||||
assert response.total_pages == 5
|
||||
assert response.ledger_entries[0].id == 99
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
"use client";
|
||||
|
||||
import { CircleDollarSign, CreditCard, RefreshCw } from "lucide-react";
|
||||
import {
|
||||
ChevronLeft,
|
||||
ChevronRight,
|
||||
CircleDollarSign,
|
||||
CreditCard,
|
||||
RefreshCw,
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
|
|
@ -23,6 +30,8 @@ import {
|
|||
import { useAppConfig } from "@/context/AppConfigContext";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
const LEDGER_PAGE_SIZE = 50;
|
||||
|
||||
const formatCredits = (value: number | null | undefined) => (
|
||||
(value ?? 0).toLocaleString(undefined, {
|
||||
maximumFractionDigits: 2,
|
||||
|
|
@ -93,13 +102,26 @@ const getRunHref = (entry: MpsCreditLedgerEntryResponse) => {
|
|||
return `/workflow/${entry.workflow_id}/run/${entry.workflow_run_id}`;
|
||||
};
|
||||
|
||||
const getPageFromSearchParams = (
|
||||
searchParams: { get: (name: string) => string | null },
|
||||
) => {
|
||||
const pageParam = searchParams.get("page");
|
||||
const page = pageParam ? Number.parseInt(pageParam, 10) : 1;
|
||||
return Number.isFinite(page) && page > 0 ? page : 1;
|
||||
};
|
||||
|
||||
export default function BillingPage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const auth = useAuth();
|
||||
const { config } = useAppConfig();
|
||||
const [credits, setCredits] = useState<MpsBillingCreditsResponse | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [refreshing, setRefreshing] = useState(false);
|
||||
const [purchasing, setPurchasing] = useState(false);
|
||||
const [currentPage, setCurrentPage] = useState(
|
||||
() => getPageFromSearchParams(searchParams),
|
||||
);
|
||||
|
||||
const isBillingV2 = credits?.billing_version === "v2";
|
||||
const canPurchaseCredits = isBillingV2 && config?.deploymentMode !== "oss";
|
||||
|
|
@ -109,8 +131,14 @@ export default function BillingPage() {
|
|||
const usagePercent = totalQuota > 0 ? Math.min(100, Math.round((usedCredits / totalQuota) * 100)) : 0;
|
||||
|
||||
const ledgerEntries = useMemo(() => credits?.ledger_entries ?? [], [credits?.ledger_entries]);
|
||||
const ledgerPage = credits?.page ?? currentPage;
|
||||
const ledgerTotalCount = credits?.total_count ?? ledgerEntries.length;
|
||||
const ledgerTotalPages = credits?.total_pages ?? 0;
|
||||
|
||||
const fetchCredits = useCallback(async ({ silent = false }: { silent?: boolean } = {}) => {
|
||||
const fetchCredits = useCallback(async (
|
||||
page: number,
|
||||
{ silent = false }: { silent?: boolean } = {},
|
||||
) => {
|
||||
if (auth.loading) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -128,7 +156,7 @@ export default function BillingPage() {
|
|||
|
||||
try {
|
||||
const response = await getBillingCreditsApiV1OrganizationsBillingCreditsGet({
|
||||
query: { limit: 50 },
|
||||
query: { page, limit: LEDGER_PAGE_SIZE },
|
||||
});
|
||||
|
||||
if (response.error) {
|
||||
|
|
@ -146,11 +174,36 @@ export default function BillingPage() {
|
|||
}, [auth.isAuthenticated, auth.loading]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchCredits();
|
||||
}, [fetchCredits]);
|
||||
const nextPage = getPageFromSearchParams(searchParams);
|
||||
setCurrentPage((previousPage) => (
|
||||
previousPage === nextPage ? previousPage : nextPage
|
||||
));
|
||||
}, [searchParams]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchCredits(currentPage);
|
||||
}, [currentPage, fetchCredits]);
|
||||
|
||||
const handleRefresh = () => {
|
||||
fetchCredits({ silent: true });
|
||||
fetchCredits(currentPage, { silent: true });
|
||||
};
|
||||
|
||||
const updateUrlPage = useCallback((page: number) => {
|
||||
const newParams = new URLSearchParams(searchParams.toString());
|
||||
if (page > 1) {
|
||||
newParams.set("page", page.toString());
|
||||
} else {
|
||||
newParams.delete("page");
|
||||
}
|
||||
|
||||
const queryString = newParams.toString();
|
||||
router.push(queryString ? `/billing?${queryString}` : "/billing");
|
||||
}, [router, searchParams]);
|
||||
|
||||
const handlePageChange = (page: number) => {
|
||||
const nextPage = Math.max(1, page);
|
||||
setCurrentPage(nextPage);
|
||||
updateUrlPage(nextPage);
|
||||
};
|
||||
|
||||
const handlePurchaseCredits = async () => {
|
||||
|
|
@ -233,7 +286,7 @@ export default function BillingPage() {
|
|||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{isBillingV2 ? "Recent ledger debit total" : "Current allocation usage"}
|
||||
{isBillingV2 ? "Total ledger debits" : "Current allocation usage"}
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
|
@ -315,6 +368,33 @@ export default function BillingPage() {
|
|||
No ledger entries yet
|
||||
</div>
|
||||
)}
|
||||
{ledgerTotalPages > 1 && (
|
||||
<div className="flex items-center justify-between mt-6">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Page {ledgerPage} of {ledgerTotalPages} ({ledgerTotalCount} total entries)
|
||||
</p>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => handlePageChange(ledgerPage - 1)}
|
||||
disabled={ledgerPage <= 1 || loading || refreshing}
|
||||
>
|
||||
<ChevronLeft className="h-4 w-4" />
|
||||
Previous
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => handlePageChange(ledgerPage + 1)}
|
||||
disabled={ledgerPage >= ledgerTotalPages || loading || refreshing}
|
||||
>
|
||||
Next
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
|
|
|
|||
|
|
@ -1256,7 +1256,7 @@ export const getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet = <ThrowOnError e
|
|||
/**
|
||||
* Get Billing Credits
|
||||
*
|
||||
* Return legacy MPS credits or v2 billing ledger details for the org.
|
||||
* Return legacy MPS credits or paginated v2 billing ledger details for the org.
|
||||
*/
|
||||
export const getBillingCreditsApiV1OrganizationsBillingCreditsGet = <ThrowOnError extends boolean = false>(options?: Options<GetBillingCreditsApiV1OrganizationsBillingCreditsGetData, ThrowOnError>) => (options?.client ?? client).get<GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses, GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors, ThrowOnError>({ url: '/api/v1/organizations/billing/credits', ...options });
|
||||
|
||||
|
|
|
|||
|
|
@ -3138,6 +3138,22 @@ export type MpsBillingCreditsResponse = {
|
|||
* Ledger Entries
|
||||
*/
|
||||
ledger_entries?: Array<MpsCreditLedgerEntryResponse>;
|
||||
/**
|
||||
* Total Count
|
||||
*/
|
||||
total_count?: number;
|
||||
/**
|
||||
* Page
|
||||
*/
|
||||
page?: number;
|
||||
/**
|
||||
* Limit
|
||||
*/
|
||||
limit?: number;
|
||||
/**
|
||||
* Total Pages
|
||||
*/
|
||||
total_pages?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -11482,6 +11498,10 @@ export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetData = {
|
|||
};
|
||||
path?: never;
|
||||
query?: {
|
||||
/**
|
||||
* Page
|
||||
*/
|
||||
page?: number;
|
||||
/**
|
||||
* Limit
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@
|
|||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
|
||||
|
||||
type ModelMode = "dograh" | "byok";
|
||||
type ModelMode = "realtime" | "dograh" | "byok";
|
||||
|
||||
interface DograhDefaults {
|
||||
voices: string[];
|
||||
|
|
@ -125,24 +125,35 @@ function effectiveConfigToLegacyShape(config: Record<string, unknown> | null): R
|
|||
};
|
||||
}
|
||||
|
||||
function emptyByokInitialConfig(): Record<string, unknown> {
|
||||
function emptyByokInitialConfig(isRealtime: boolean): Record<string, unknown> {
|
||||
return {
|
||||
is_realtime: false,
|
||||
is_realtime: isRealtime,
|
||||
};
|
||||
}
|
||||
|
||||
// The v2 editor surfaces realtime ("Speech to Speech") and pipeline (BYOK) as
|
||||
// separate tabs, so each tab gets its own initial config. A tab is pre-filled
|
||||
// only when the saved (or effective) configuration matches that tab's mode;
|
||||
// otherwise it starts empty so the other tab's data does not leak across.
|
||||
function getByokInitialConfig(
|
||||
configuration: Record<string, unknown> | null,
|
||||
effectiveConfiguration: Record<string, unknown> | null,
|
||||
wantRealtime: boolean,
|
||||
): Record<string, unknown> {
|
||||
const byokConfiguration = byokConfigToLegacyShape(configuration);
|
||||
if (byokConfiguration) return byokConfiguration;
|
||||
const matchesTab = (config: Record<string, unknown> | null) =>
|
||||
config ? Boolean(config.is_realtime) === wantRealtime : false;
|
||||
|
||||
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
|
||||
return emptyByokInitialConfig();
|
||||
const byokConfiguration = byokConfigToLegacyShape(configuration);
|
||||
if (byokConfiguration) {
|
||||
return matchesTab(byokConfiguration) ? byokConfiguration : emptyByokInitialConfig(wantRealtime);
|
||||
}
|
||||
|
||||
return effectiveConfigToLegacyShape(effectiveConfiguration) || emptyByokInitialConfig();
|
||||
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
|
||||
return emptyByokInitialConfig(wantRealtime);
|
||||
}
|
||||
|
||||
const effective = effectiveConfigToLegacyShape(effectiveConfiguration);
|
||||
return matchesTab(effective) ? (effective as Record<string, unknown>) : emptyByokInitialConfig(wantRealtime);
|
||||
}
|
||||
|
||||
function buildDograhState(
|
||||
|
|
@ -185,10 +196,12 @@ function preferredMode(
|
|||
configuration: Record<string, unknown> | null,
|
||||
effectiveConfiguration: Record<string, unknown> | null,
|
||||
): ModelMode {
|
||||
if (configuration?.mode === "dograh" || configuration?.mode === "byok") {
|
||||
return configuration.mode;
|
||||
if (configuration?.mode === "dograh") return "dograh";
|
||||
if (configuration?.mode === "byok") {
|
||||
return asRecord(configuration.byok)?.mode === "realtime" ? "realtime" : "byok";
|
||||
}
|
||||
return isDograhEffectiveConfig(effectiveConfiguration) ? "dograh" : "byok";
|
||||
if (isDograhEffectiveConfig(effectiveConfiguration)) return "dograh";
|
||||
return Boolean(effectiveConfiguration?.is_realtime) ? "realtime" : "byok";
|
||||
}
|
||||
|
||||
function hasRequiredApiKey(
|
||||
|
|
@ -249,7 +262,8 @@ export function AIModelConfigurationV2Editor({
|
|||
speed: defaults.dograh.defaults.speed,
|
||||
language: defaults.dograh.defaults.language,
|
||||
}));
|
||||
const [byokInitialConfig, setByokInitialConfig] = useState<Record<string, unknown> | null>(null);
|
||||
const [realtimeInitialConfig, setRealtimeInitialConfig] = useState<Record<string, unknown> | null>(null);
|
||||
const [pipelineInitialConfig, setPipelineInitialConfig] = useState<Record<string, unknown> | null>(null);
|
||||
const [isSavingDograh, setIsSavingDograh] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
|
|
@ -258,7 +272,8 @@ export function AIModelConfigurationV2Editor({
|
|||
const rawEffectiveConfiguration = asRecord(effectiveConfiguration);
|
||||
setMode(preferredMode(rawConfiguration, rawEffectiveConfiguration));
|
||||
setDograh(buildDograhState(defaults, rawConfiguration, rawEffectiveConfiguration));
|
||||
setByokInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration));
|
||||
setRealtimeInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, true));
|
||||
setPipelineInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, false));
|
||||
}, [configuration, defaults, effectiveConfiguration]);
|
||||
|
||||
const saveDograhConfiguration = async () => {
|
||||
|
|
@ -322,28 +337,30 @@ export function AIModelConfigurationV2Editor({
|
|||
)}
|
||||
|
||||
<Tabs value={mode} onValueChange={(value) => setMode(value as ModelMode)} className="space-y-6">
|
||||
<TabsList className="grid w-full grid-cols-2">
|
||||
<TabsList className="grid w-full grid-cols-3">
|
||||
<TabsTrigger value="realtime">Speech to Speech</TabsTrigger>
|
||||
<TabsTrigger value="dograh">Dograh</TabsTrigger>
|
||||
<TabsTrigger value="byok">BYOK</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="realtime" className="mt-0">
|
||||
<p className="mb-4 text-sm text-muted-foreground">
|
||||
A single speech-to-speech model handles the conversation in realtime (no separate transcriber or voice). An LLM is still required for variable extraction and QA.
|
||||
</p>
|
||||
<ServiceConfigurationForm
|
||||
key={`realtime-${JSON.stringify(realtimeInitialConfig)}`}
|
||||
mode="global"
|
||||
forceRealtime
|
||||
configurationDefaults={defaultsForByok}
|
||||
initialConfig={realtimeInitialConfig}
|
||||
submitLabel={submitLabel}
|
||||
onSave={saveByokConfiguration}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="dograh" className="mt-0">
|
||||
<div className="rounded-lg border p-5">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2 sm:col-span-2">
|
||||
<Label htmlFor="dograh-api-key">API Key</Label>
|
||||
<div className="relative">
|
||||
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
<Input
|
||||
id="dograh-api-key"
|
||||
className="pl-9"
|
||||
value={dograh.api_key}
|
||||
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
|
||||
placeholder="Enter API key"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>Voice</Label>
|
||||
<Select value={dograh.voice} onValueChange={(voice) => setDograh({ ...dograh, voice })}>
|
||||
|
|
@ -394,6 +411,20 @@ export function AIModelConfigurationV2Editor({
|
|||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2 sm:col-span-2">
|
||||
<Label htmlFor="dograh-api-key">API Key</Label>
|
||||
<div className="relative">
|
||||
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
<Input
|
||||
id="dograh-api-key"
|
||||
className="pl-9"
|
||||
value={dograh.api_key}
|
||||
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
|
||||
placeholder="Enter API key"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button type="button" className="mt-6 w-full" onClick={saveDograhConfiguration} disabled={isSavingDograh}>
|
||||
|
|
@ -405,10 +436,11 @@ export function AIModelConfigurationV2Editor({
|
|||
|
||||
<TabsContent value="byok" className="mt-0">
|
||||
<ServiceConfigurationForm
|
||||
key={JSON.stringify(byokInitialConfig)}
|
||||
key={`byok-${JSON.stringify(pipelineInitialConfig)}`}
|
||||
mode="global"
|
||||
forceRealtime={false}
|
||||
configurationDefaults={defaultsForByok}
|
||||
initialConfig={byokInitialConfig}
|
||||
initialConfig={pipelineInitialConfig}
|
||||
submitLabel={submitLabel}
|
||||
onSave={saveByokConfiguration}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -101,6 +101,13 @@ export interface ServiceConfigurationFormProps {
|
|||
submitLabel?: string;
|
||||
configurationDefaults?: ServiceConfigurationDefaults | null;
|
||||
initialConfig?: Record<string, unknown> | null;
|
||||
/**
|
||||
* When set, locks the realtime/pipeline mode to this value and hides the
|
||||
* in-form toggle. The v2 editor uses this to surface realtime
|
||||
* ("Speech to Speech") and pipeline (BYOK) as separate top-level tabs.
|
||||
* Leave undefined to keep the user-controllable toggle (legacy + overrides).
|
||||
*/
|
||||
forceRealtime?: boolean;
|
||||
}
|
||||
|
||||
function getProviderDisplayName(
|
||||
|
|
@ -130,10 +137,11 @@ export function ServiceConfigurationForm({
|
|||
submitLabel,
|
||||
configurationDefaults,
|
||||
initialConfig,
|
||||
forceRealtime,
|
||||
}: ServiceConfigurationFormProps) {
|
||||
const [apiError, setApiError] = useState<string | null>(null);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isRealtime, setIsRealtime] = useState(false);
|
||||
const [isRealtime, setIsRealtime] = useState(forceRealtime ?? false);
|
||||
const { userConfig } = useUserConfig();
|
||||
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
|
||||
llm: {},
|
||||
|
|
@ -227,9 +235,9 @@ export function ServiceConfigurationForm({
|
|||
realtime: realtimeSchemas,
|
||||
});
|
||||
|
||||
// Restore realtime toggle
|
||||
// Restore realtime toggle (skip when the parent locks the mode)
|
||||
const configData = configSource as Record<string, unknown> | null;
|
||||
if (configData?.is_realtime) {
|
||||
if (forceRealtime === undefined && configData?.is_realtime) {
|
||||
setIsRealtime(true);
|
||||
}
|
||||
|
||||
|
|
@ -867,22 +875,24 @@ export function ServiceConfigurationForm({
|
|||
|
||||
return (
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
{/* Realtime toggle */}
|
||||
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
|
||||
<div>
|
||||
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
|
||||
Realtime Mode
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
|
||||
</p>
|
||||
{/* Realtime toggle — hidden when the parent locks the mode (v2 tabs) */}
|
||||
{forceRealtime === undefined && (
|
||||
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
|
||||
<div>
|
||||
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
|
||||
Realtime Mode
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="realtime-toggle"
|
||||
checked={isRealtime}
|
||||
onCheckedChange={setIsRealtime}
|
||||
/>
|
||||
</div>
|
||||
<Switch
|
||||
id="realtime-toggle"
|
||||
checked={isRealtime}
|
||||
onCheckedChange={setIsRealtime}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue