feat: show error if quota is exceeded

This commit is contained in:
Abhishek Kumar 2025-11-26 17:40:11 +07:00
parent 3d710cafb1
commit b804daff77
5 changed files with 219 additions and 27 deletions

View file

@ -20,6 +20,8 @@ from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.services.auth.depends import get_user_ws
from api.services.configuration.registry import ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
from api.services.pipecat.run_pipeline import run_pipeline_smallwebrtc
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
from pipecat.utils.context import set_current_run_id
@ -37,6 +39,75 @@ class SignalingManager:
self._connections: Dict[str, WebSocket] = {}
self._peer_connections: Dict[str, SmallWebRTCConnection] = {}
async def _check_dograh_quota(self, user: UserModel) -> tuple[bool, str]:
"""Check if user has sufficient Dograh quota for making a call.
Args:
user_id: The user ID to check quota for
Returns:
Tuple of (has_quota, error_message)
- has_quota: True if user has sufficient quota or not using Dograh
- error_message: Error message if quota check fails, empty string otherwise
"""
try:
# Get user configurations
user_config = await db_client.get_user_configurations(user.id)
# Check if user is using any Dograh service
using_dograh = False
dograh_api_keys = set()
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.llm.api_key)
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.stt.api_key)
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.tts.api_key)
# If not using Dograh, quota check passes
if not using_dograh:
return True, ""
# Check quota for ALL Dograh keys
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key, created_by=user.provider_id
)
remaining = usage.get("remaining_credits", 0.0)
# Require at least $0.10 for a short call
if remaining < 0.10:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
return False, (
"You have exhausted your trial credits."
"Please email founders@dograh.com for additional credits."
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
return False, "Could not verify Dograh credits. Please try again."
return True, ""
except Exception as e:
logger.error(f"Error during quota check: {str(e)}")
# On unexpected error, allow the call to proceed
return True, ""
async def handle_websocket(
self,
websocket: WebSocket,
@ -110,6 +181,21 @@ class SignalingManager:
# Set run context for logging
set_current_run_id(workflow_run_id)
# Check Dograh quota before initiating the call
has_quota, error_message = await self._check_dograh_quota(user)
if not has_quota:
# Send error response for quota issues
await ws.send_json(
{
"type": "error",
"payload": {
"error_type": "quota_exceeded",
"message": error_message,
},
}
)
return
if pc_id and pc_id in self._peer_connections:
# Reuse existing connection
logger.info(f"Reusing existing connection for pc_id: {pc_id}")

View file

@ -19,13 +19,33 @@ class MPSServiceKeyClient:
self.base_url = MPS_API_URL
self.timeout = httpx.Timeout(10.0)
def _get_headers(self) -> dict:
"""Get headers for MPS API requests."""
def _get_headers(
self,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> dict:
"""
Get headers for MPS API requests.
Args:
organization_id: Organization ID for authenticated mode
created_by: User provider ID for OSS mode
Returns:
Dictionary of headers
"""
headers = {"Content-Type": "application/json"}
# Add authentication for non-OSS mode
if DEPLOYMENT_MODE != "oss" and DOGRAH_MPS_SECRET_KEY:
headers["X-Secret-Key"] = DOGRAH_MPS_SECRET_KEY
if DEPLOYMENT_MODE != "oss":
if DOGRAH_MPS_SECRET_KEY:
headers["X-Secret-Key"] = DOGRAH_MPS_SECRET_KEY
if organization_id:
headers["X-Organization-Id"] = str(organization_id)
else:
# OSS mode
if created_by:
headers["X-Created-By"] = created_by
return headers
@ -58,7 +78,7 @@ class MPSServiceKeyClient:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/",
json=request_body,
headers=self._get_headers(),
headers=self._get_headers(organization_id, created_by),
)
if response.status_code == 200:
@ -116,7 +136,7 @@ class MPSServiceKeyClient:
response = await client.get(
f"{self.base_url}/api/v1/service-keys/",
params=params,
headers=self._get_headers(),
headers=self._get_headers(organization_id, created_by),
)
if response.status_code == 200:
@ -152,7 +172,7 @@ class MPSServiceKeyClient:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/service-keys/{key_id}",
headers=self._get_headers(),
headers=self._get_headers(organization_id, created_by),
)
if response.status_code == 200:
@ -209,7 +229,7 @@ class MPSServiceKeyClient:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.delete(
f"{self.base_url}/api/v1/service-keys/{key_id}",
headers=self._get_headers(),
headers=self._get_headers(organization_id, created_by),
)
if response.status_code in [200, 204]:
@ -220,6 +240,51 @@ class MPSServiceKeyClient:
)
return False
async def check_service_key_usage(
self,
service_key: str,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> dict:
"""
Check the usage and quota of a service key.
Args:
service_key: The service key to check usage for
organization_id: Organization ID (for authenticated mode)
created_by: User provider ID (for OSS mode)
Returns:
Dictionary containing:
- total_credits_used: Total credits consumed
- remaining_credits: Credits remaining in quota
Raises:
HTTPException: If the API call fails
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/usage",
json={"service_key": service_key},
headers=self._get_headers(organization_id, created_by),
)
if response.status_code == 200:
data = response.json()
return {
"total_credits_used": data.get("total_credits_used", 0.0),
"remaining_credits": data.get("remaining_credits", 0.0),
}
else:
logger.error(
f"Failed to check service key usage: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to check service key usage: {response.text}",
request=response.request,
response=response,
)
async def call_workflow_api(
self,
call_type: str,
@ -247,17 +312,6 @@ class MPSServiceKeyClient:
Raises:
HTTPException: If the API call fails
"""
headers = {"Content-Type": "application/json"}
# Add secret key authentication
if DEPLOYMENT_MODE != "oss" and DOGRAH_MPS_SECRET_KEY:
headers["X-Secret-Key"] = DOGRAH_MPS_SECRET_KEY
if organization_id:
headers["X-Organization-Id"] = str(organization_id)
elif DEPLOYMENT_MODE == "oss":
if created_by:
headers["X-Created-By"] = created_by
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
response = await client.post(
f"{self.base_url}/api/v1/workflow/create-workflow",
@ -266,7 +320,7 @@ class MPSServiceKeyClient:
"use_case": use_case,
"activity_description": activity_description,
},
headers=headers,
headers=self._get_headers(organization_id, created_by),
)
if response.status_code == 200: