mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
fix: create mps account on migrate to v2
This commit is contained in:
parent
8f241b89d2
commit
724e1d456b
14 changed files with 666 additions and 61 deletions
|
|
@ -1,9 +1,13 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
|
|
@ -358,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
|||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
|
||||
monkeypatch,
|
||||
):
|
||||
from api.routes import organization as organization_routes
|
||||
|
||||
legacy = EffectiveAIModelConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
)
|
||||
expected_response = OrganizationAIModelConfigurationResponse(
|
||||
configuration={"version": 2, "mode": "dograh"},
|
||||
effective_configuration={},
|
||||
source="organization_v2",
|
||||
)
|
||||
|
||||
class FakeValidator:
|
||||
async def validate(self, *args, **kwargs):
|
||||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
upsert = AsyncMock()
|
||||
migrate_workflows = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"get_organization_ai_model_configuration_v2",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=legacy),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"UserConfigurationValidator",
|
||||
lambda: FakeValidator(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"upsert_organization_ai_model_configuration_v2",
|
||||
upsert,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"migrate_workflow_model_configurations_to_v2",
|
||||
migrate_workflows,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"_model_configuration_v2_response",
|
||||
AsyncMock(return_value=expected_response),
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_routes.migrate_model_configuration_v2(
|
||||
force=False,
|
||||
user=user,
|
||||
)
|
||||
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
|
||||
upsert.assert_awaited_once()
|
||||
migrate_workflows.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
assert response == expected_response
|
||||
|
|
|
|||
68
api/tests/test_auth_depends.py
Normal file
68
api/tests/test_auth_depends.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.auth import depends as auth_depends
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
|
||||
stack_user = {
|
||||
"id": "stack-user-1",
|
||||
"selected_team_id": "team-1",
|
||||
"primary_email_verified": False,
|
||||
}
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
email=None,
|
||||
provider_id="stack-user-1",
|
||||
selected_organization_id=None,
|
||||
)
|
||||
organization = SimpleNamespace(id=42)
|
||||
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
|
||||
monkeypatch.setattr(
|
||||
auth_depends.stackauth,
|
||||
"get_user",
|
||||
AsyncMock(return_value=stack_user),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_user_by_provider_id",
|
||||
AsyncMock(return_value=(user, False)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_organization_by_provider_id",
|
||||
AsyncMock(return_value=(organization, True)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"add_user_to_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"update_user_selected_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=existing_config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
|
||||
result = await auth_depends.get_user(authorization="Bearer token")
|
||||
|
||||
assert result is user
|
||||
assert result.selected_organization_id == 42
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")
|
||||
|
|
@ -175,6 +175,130 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.ensure_billing_account_v2(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
) == {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/balance",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, params, headers):
|
||||
calls.append(("GET", url, params, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_credit_ledger(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
) == {
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
|
||||
{"page": 3, "limit": 25},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
|
||||
calls = []
|
||||
|
|
|
|||
|
|
@ -31,3 +31,69 @@ async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch)
|
|||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
|
||||
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_usage,
|
||||
"_get_mps_billing_account_status",
|
||||
AsyncMock(return_value={"billing_mode": "v2"}),
|
||||
)
|
||||
get_ledger = AsyncMock(
|
||||
return_value={
|
||||
"account": {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": 250,
|
||||
"currency": "USD",
|
||||
},
|
||||
"ledger_entries": [
|
||||
{
|
||||
"id": 99,
|
||||
"entry_type": "grant",
|
||||
"origin": "account_creation",
|
||||
"credits_delta": 250,
|
||||
"balance_after": 250,
|
||||
"created_at": "2026-06-12T00:00:00Z",
|
||||
}
|
||||
],
|
||||
"total_debits_credits": 75,
|
||||
"total_count": 101,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 5,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_credit_ledger",
|
||||
get_ledger,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_usage.get_billing_credits(
|
||||
page=3,
|
||||
limit=25,
|
||||
user=user,
|
||||
)
|
||||
|
||||
get_ledger.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
created_by="provider-123",
|
||||
)
|
||||
assert response.billing_version == "v2"
|
||||
assert response.total_credits_used == 75
|
||||
assert response.total_count == 101
|
||||
assert response.page == 3
|
||||
assert response.limit == 25
|
||||
assert response.total_pages == 5
|
||||
assert response.ledger_entries[0].id == 99
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue