fix: create mps account on migrate to v2

This commit is contained in:
Abhishek Kumar 2026-06-12 14:53:36 +05:30
parent 8f241b89d2
commit 724e1d456b
14 changed files with 666 additions and 61 deletions

View file

@ -1,9 +1,13 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
@ -358,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
@pytest.mark.asyncio
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
monkeypatch,
):
from api.routes import organization as organization_routes
legacy = EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
tts=DograhTTSService(
provider="dograh",
api_key=["mps-secret"],
model="default",
voice="default",
),
stt=DograhSTTService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
)
expected_response = OrganizationAIModelConfigurationResponse(
configuration={"version": 2, "mode": "dograh"},
effective_configuration={},
source="organization_v2",
)
class FakeValidator:
async def validate(self, *args, **kwargs):
return {"status": [{"model": "all", "message": "ok"}]}
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
upsert = AsyncMock()
migrate_workflows = AsyncMock()
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_routes,
"get_organization_ai_model_configuration_v2",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
organization_routes.db_client,
"get_user_configurations",
AsyncMock(return_value=legacy),
)
monkeypatch.setattr(
organization_routes,
"UserConfigurationValidator",
lambda: FakeValidator(),
)
monkeypatch.setattr(
organization_routes,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
monkeypatch.setattr(
organization_routes,
"upsert_organization_ai_model_configuration_v2",
upsert,
)
monkeypatch.setattr(
organization_routes,
"migrate_workflow_model_configurations_to_v2",
migrate_workflows,
)
monkeypatch.setattr(
organization_routes,
"_model_configuration_v2_response",
AsyncMock(return_value=expected_response),
)
user = SimpleNamespace(
id=7,
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_routes.migrate_model_configuration_v2(
force=False,
user=user,
)
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
upsert.assert_awaited_once()
migrate_workflows.assert_awaited_once_with(
organization_id=42,
fallback_user_config=legacy,
)
assert response == expected_response

View file

@ -0,0 +1,68 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.auth import depends as auth_depends
@pytest.mark.asyncio
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
stack_user = {
"id": "stack-user-1",
"selected_team_id": "team-1",
"primary_email_verified": False,
}
user = SimpleNamespace(
id=7,
email=None,
provider_id="stack-user-1",
selected_organization_id=None,
)
organization = SimpleNamespace(id=42)
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
monkeypatch.setattr(
auth_depends.stackauth,
"get_user",
AsyncMock(return_value=stack_user),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_user_by_provider_id",
AsyncMock(return_value=(user, False)),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_organization_by_provider_id",
AsyncMock(return_value=(organization, True)),
)
monkeypatch.setattr(
auth_depends.db_client,
"add_user_to_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"update_user_selected_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_user_configurations",
AsyncMock(return_value=existing_config),
)
monkeypatch.setattr(
auth_depends,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
result = await auth_depends.get_user(authorization="Bearer token")
assert result is user
assert result.selected_organization_id == 42
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")

View file

@ -175,6 +175,130 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
]
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(
200,
{
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.ensure_billing_account_v2(
organization_id=42,
created_by="provider-123",
) == {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/balance",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, params, headers):
calls.append(("GET", url, params, headers))
return _Response(
200,
{
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_credit_ledger(
organization_id=42,
page=3,
limit=25,
) == {
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
{"page": 3, "limit": 25},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
calls = []

View file

@ -31,3 +31,69 @@ async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch)
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_usage,
"_get_mps_billing_account_status",
AsyncMock(return_value={"billing_mode": "v2"}),
)
get_ledger = AsyncMock(
return_value={
"account": {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": 250,
"currency": "USD",
},
"ledger_entries": [
{
"id": 99,
"entry_type": "grant",
"origin": "account_creation",
"credits_delta": 250,
"balance_after": 250,
"created_at": "2026-06-12T00:00:00Z",
}
],
"total_debits_credits": 75,
"total_count": 101,
"page": 3,
"limit": 25,
"total_pages": 5,
}
)
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_credit_ledger",
get_ledger,
)
user = SimpleNamespace(
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_usage.get_billing_credits(
page=3,
limit=25,
user=user,
)
get_ledger.assert_awaited_once_with(
organization_id=42,
page=3,
limit=25,
created_by="provider-123",
)
assert response.billing_version == "v2"
assert response.total_credits_used == 75
assert response.total_count == 101
assert response.page == 3
assert response.limit == 25
assert response.total_pages == 5
assert response.ledger_entries[0].id == 99