mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
fix: create mps account on migrate to v2
This commit is contained in:
parent
8f241b89d2
commit
724e1d456b
14 changed files with 666 additions and 61 deletions
|
|
@ -12,6 +12,7 @@ from api.enums import PostHogEvent
|
|||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import decode_jwt_token
|
||||
|
||||
|
|
@ -110,6 +111,19 @@ async def get_user(
|
|||
# This prevents race conditions where multiple concurrent requests
|
||||
# might try to create configurations
|
||||
if org_was_created:
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization.id,
|
||||
created_by=str(stack_user["id"]),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize hosted MPS billing account for "
|
||||
"organization {}",
|
||||
organization.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
existing_cfg = await db_client.get_user_configurations(user_model.id)
|
||||
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
|
|
@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": "Auto-generated key for OSS user",
|
||||
"expires_in_days": 7, # Short-lived for OSS
|
||||
"created_by": user_provider_id,
|
||||
|
|
@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": f"Auto-generated key for organization {organization_id}",
|
||||
"organization_id": organization_id,
|
||||
"expires_in_days": 90, # Longer-lived for authenticated users
|
||||
|
|
|
|||
23
api/services/mps_billing.py
Normal file
23
api/services/mps_billing.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
async def ensure_hosted_mps_billing_account_v2(
|
||||
organization_id: int,
|
||||
*,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Ensure hosted orgs have an MPS billing v2 account.
|
||||
|
||||
OSS deployments use legacy per-key quota accounting and do not create MPS
|
||||
billing accounts.
|
||||
"""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return None
|
||||
|
||||
return await mps_service_key_client.ensure_billing_account_v2(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
@ -394,6 +394,7 @@ class MPSServiceKeyClient:
|
|||
async def get_credit_ledger(
|
||||
self,
|
||||
organization_id: int,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
@ -401,7 +402,7 @@ class MPSServiceKeyClient:
|
|||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
|
||||
params={"limit": limit},
|
||||
params={"page": page, "limit": limit},
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
|
|
@ -449,6 +450,34 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def ensure_billing_account_v2(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create or return the MPS v2 billing account for an organization."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to ensure MPS billing account v2: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to ensure MPS billing account v2: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue