mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: send auth credentials with validate service keys
This commit is contained in:
parent
123114fb94
commit
83f05ab146
9 changed files with 83 additions and 24 deletions
|
|
@ -125,7 +125,11 @@ async def update_user_configurations(
|
|||
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(user_configurations)
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
|
||||
|
|
@ -163,7 +167,11 @@ async def validate_user_configurations(
|
|||
):
|
||||
validator = UserConfigurationValidator()
|
||||
try:
|
||||
status = await validator.validate(configurations)
|
||||
status = await validator.validate(
|
||||
configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
await db_client.update_user_configuration_last_validated_at(user.id)
|
||||
return status
|
||||
except ValueError as e:
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ class SignalingManager:
|
|||
{
|
||||
"type": "error",
|
||||
"payload": {
|
||||
"error_type": "quota_exceeded",
|
||||
"error_type": quota_result.error_code,
|
||||
"message": quota_result.error_message,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,12 @@ from api.schemas.user_configuration import (
|
|||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
AuthContext = TypedDict(
|
||||
"AuthContext",
|
||||
{"organization_id": Optional[int], "created_by": Optional[str]},
|
||||
total=False,
|
||||
)
|
||||
|
||||
|
||||
class APIKeyStatus(TypedDict):
|
||||
model: str
|
||||
|
|
@ -43,7 +49,16 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.SELF_HOSTED.value: self._check_self_hosted_api_key,
|
||||
}
|
||||
|
||||
async def validate(self, configuration: UserConfiguration) -> APIKeyStatusResponse:
|
||||
async def validate(
|
||||
self,
|
||||
configuration: UserConfiguration,
|
||||
organization_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> APIKeyStatusResponse:
|
||||
self._auth_context: AuthContext = {
|
||||
"organization_id": organization_id,
|
||||
"created_by": created_by,
|
||||
}
|
||||
status_list = []
|
||||
|
||||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
|
|
@ -165,7 +180,12 @@ class UserConfigurationValidator:
|
|||
"You provided a Dograh API key (dgr...) instead of a service key. "
|
||||
"Please use a service key (mps...)."
|
||||
)
|
||||
return mps_service_key_client.validate_service_key(api_key)
|
||||
auth = getattr(self, "_auth_context", {})
|
||||
return mps_service_key_client.validate_service_key(
|
||||
api_key,
|
||||
organization_id=auth.get("organization_id"),
|
||||
created_by=auth.get("created_by"),
|
||||
)
|
||||
|
||||
def _check_sarvam_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -276,7 +276,7 @@ class MPSServiceKeyClient:
|
|||
"remaining_credits": data.get("remaining_credits", 0.0),
|
||||
}
|
||||
else:
|
||||
logger.error(
|
||||
logger.warning(
|
||||
f"Failed to check service key usage: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
|
|
@ -416,7 +416,12 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
def validate_service_key(self, service_key: str) -> bool:
|
||||
def validate_service_key(
|
||||
self,
|
||||
service_key: str,
|
||||
organization_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Synchronously validate a Dograh service key by checking usage via MPS.
|
||||
|
||||
|
|
@ -427,7 +432,7 @@ class MPSServiceKeyClient:
|
|||
response = client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/usage",
|
||||
json={"service_key": service_key},
|
||||
headers=self._get_headers(),
|
||||
headers=self._get_headers(organization_id, created_by),
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class QuotaCheckResult:
|
|||
|
||||
has_quota: bool
|
||||
error_message: str = ""
|
||||
error_code: str = ""
|
||||
|
||||
|
||||
async def check_dograh_quota(user: UserModel) -> QuotaCheckResult:
|
||||
|
|
@ -76,6 +77,7 @@ async def check_dograh_quota(user: UserModel) -> QuotaCheckResult:
|
|||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=(
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
|
|
@ -89,8 +91,16 @@ async def check_dograh_quota(user: UserModel) -> QuotaCheckResult:
|
|||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,12 @@ inline WebSocket media streaming.
|
|||
|
||||
import json
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.telephony.base import (
|
||||
CallInitiationResult,
|
||||
|
|
@ -16,8 +19,6 @@ from api.services.telephony.base import (
|
|||
TelephonyProvider,
|
||||
)
|
||||
from api.utils.common import get_backend_endpoints
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ const BrowserCall = ({ workflowId, workflowRunId, initialContextVariables }: {
|
|||
apiKeyModalOpen,
|
||||
setApiKeyModalOpen,
|
||||
apiKeyError,
|
||||
apiKeyErrorCode,
|
||||
workflowConfigError,
|
||||
workflowConfigModalOpen,
|
||||
setWorkflowConfigModalOpen,
|
||||
|
|
@ -91,10 +92,14 @@ const BrowserCall = ({ workflowId, workflowRunId, initialContextVariables }: {
|
|||
};
|
||||
}, [isCompleted, auth.isAuthenticated, workflowId, workflowRunId]);
|
||||
|
||||
const navigateToApiKeys = () => {
|
||||
const navigateToCredits = () => {
|
||||
router.push('/api-keys');
|
||||
};
|
||||
|
||||
const navigateToModelConfig = () => {
|
||||
router.push('/model-configurations');
|
||||
};
|
||||
|
||||
const navigateToWorkflow = () => {
|
||||
router.push(`/workflow/${workflowId}`)
|
||||
}
|
||||
|
|
@ -161,7 +166,9 @@ const BrowserCall = ({ workflowId, workflowRunId, initialContextVariables }: {
|
|||
open={apiKeyModalOpen}
|
||||
onOpenChange={setApiKeyModalOpen}
|
||||
error={apiKeyError}
|
||||
onNavigateToApiKeys={navigateToApiKeys}
|
||||
errorCode={apiKeyErrorCode}
|
||||
onNavigateToCredits={navigateToCredits}
|
||||
onNavigateToModelConfig={navigateToModelConfig}
|
||||
/>
|
||||
|
||||
<WorkflowConfigErrorDialog
|
||||
|
|
|
|||
|
|
@ -7,23 +7,25 @@ interface ApiKeyErrorDialogProps {
|
|||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
error: string | null;
|
||||
onNavigateToApiKeys: () => void;
|
||||
errorCode: string | null;
|
||||
onNavigateToCredits: () => void;
|
||||
onNavigateToModelConfig: () => void;
|
||||
}
|
||||
|
||||
export const ApiKeyErrorDialog = ({
|
||||
open,
|
||||
onOpenChange,
|
||||
error,
|
||||
onNavigateToApiKeys
|
||||
errorCode,
|
||||
onNavigateToCredits,
|
||||
onNavigateToModelConfig,
|
||||
}: ApiKeyErrorDialogProps) => {
|
||||
// Check if this is a quota error based on the error message
|
||||
const isQuotaError = error?.toLowerCase().includes('insufficient') ||
|
||||
error?.toLowerCase().includes('credits') ||
|
||||
error?.toLowerCase().includes('quota');
|
||||
const isQuotaError = errorCode === 'quota_exceeded';
|
||||
|
||||
const title = isQuotaError ? "Insufficient Credits" : "API Configuration Error";
|
||||
const icon = isQuotaError ? <CreditCard className="h-5 w-5 text-orange-500" /> : <Key className="h-5 w-5 text-red-500" />;
|
||||
const buttonText = isQuotaError ? "Add Credits" : "Go to API Keys Settings";
|
||||
const buttonText = isQuotaError ? "Add Credits" : "Go to Model Configurations";
|
||||
const onNavigate = isQuotaError ? onNavigateToCredits : onNavigateToModelConfig;
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
|
|
@ -51,7 +53,7 @@ export const ApiKeyErrorDialog = ({
|
|||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={onNavigateToApiKeys}>
|
||||
<Button onClick={onNavigate}>
|
||||
{buttonText}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
|
|||
const [isCompleted, setIsCompleted] = useState(false);
|
||||
const [apiKeyModalOpen, setApiKeyModalOpen] = useState(false);
|
||||
const [apiKeyError, setApiKeyError] = useState<string | null>(null);
|
||||
const [apiKeyErrorCode, setApiKeyErrorCode] = useState<string | null>(null);
|
||||
const [workflowConfigModalOpen, setWorkflowConfigModalOpen] = useState(false);
|
||||
const [workflowConfigError, setWorkflowConfigError] = useState<string | null>(null);
|
||||
const [isStarting, setIsStarting] = useState(false);
|
||||
|
|
@ -264,12 +265,15 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
|
|||
break;
|
||||
|
||||
case 'error':
|
||||
// Check if this is a quota exceeded error
|
||||
if (message.payload?.error_type === 'quota_exceeded') {
|
||||
// Check if this is a quota/service key error
|
||||
if (message.payload?.error_type === 'quota_exceeded' ||
|
||||
message.payload?.error_type === 'invalid_service_key' ||
|
||||
message.payload?.error_type === 'quota_check_failed') {
|
||||
// Log as info since it's a handled business logic case
|
||||
logger.info('Quota exceeded, showing user dialog:', message.payload.message);
|
||||
logger.info('Quota/service key error, showing user dialog:', message.payload.message);
|
||||
|
||||
// Set error state for display
|
||||
setApiKeyErrorCode(message.payload.error_type);
|
||||
setApiKeyError(message.payload.message || 'Service quota exceeded');
|
||||
setApiKeyModalOpen(true);
|
||||
|
||||
|
|
@ -545,6 +549,7 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
|
|||
|
||||
if (response.error) {
|
||||
setApiKeyModalOpen(true);
|
||||
setApiKeyErrorCode('invalid_api_key');
|
||||
let msg = 'API Key Error';
|
||||
const detail = (response.error as unknown as { detail?: { errors: { model: string; message: string }[] } }).detail;
|
||||
if (Array.isArray(detail)) {
|
||||
|
|
@ -685,6 +690,7 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
|
|||
apiKeyModalOpen,
|
||||
setApiKeyModalOpen,
|
||||
apiKeyError,
|
||||
apiKeyErrorCode,
|
||||
workflowConfigError,
|
||||
workflowConfigModalOpen,
|
||||
setWorkflowConfigModalOpen,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue