mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
feat: allow multiple API keys
This commit is contained in:
parent
162bfabac3
commit
226b4cff91
8 changed files with 172 additions and 133 deletions
|
|
@ -71,10 +71,10 @@ async def get_auth_user(
|
|||
|
||||
|
||||
class UserConfigurationRequestResponseSchema(BaseModel):
|
||||
llm: dict[str, Union[str, float]] | None = None
|
||||
tts: dict[str, Union[str, float]] | None = None
|
||||
stt: dict[str, Union[str, float]] | None = None
|
||||
embeddings: dict[str, Union[str, float]] | None = None
|
||||
llm: dict[str, Union[str, float, list[str]]] | None = None
|
||||
tts: dict[str, Union[str, float, list[str]]] | None = None
|
||||
stt: dict[str, Union[str, float, list[str]]] | None = None
|
||||
embeddings: dict[str, Union[str, float, list[str]]] | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
organization_pricing: dict[str, Union[float, str, bool]] | None = None
|
||||
|
|
|
|||
|
|
@ -41,6 +41,36 @@ def is_mask_of(masked: str, real_key: str) -> bool:
|
|||
return mask_key(real_key) == masked
|
||||
|
||||
|
||||
def resolve_masked_api_keys(
|
||||
incoming: str | list[str], existing: str | list[str]
|
||||
) -> str | list[str]:
|
||||
"""Resolve masked API keys against existing real keys.
|
||||
|
||||
For each incoming key, if it matches the mask of an existing key, the real
|
||||
key is restored. New (unmasked) keys are kept as-is. This handles adds,
|
||||
removes, reorders, and partial replacements correctly.
|
||||
"""
|
||||
if isinstance(incoming, str) and isinstance(existing, str):
|
||||
return existing if is_mask_of(incoming, existing) else incoming
|
||||
|
||||
existing_list = existing if isinstance(existing, list) else [existing]
|
||||
incoming_list = incoming if isinstance(incoming, list) else [incoming]
|
||||
|
||||
resolved: list[str] = []
|
||||
used: set[int] = set()
|
||||
for key in incoming_list:
|
||||
matched = False
|
||||
for i, real in enumerate(existing_list):
|
||||
if i not in used and is_mask_of(key, real):
|
||||
resolved.append(real)
|
||||
used.add(i)
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
resolved.append(key)
|
||||
return resolved
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level helpers for UserConfiguration objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -53,7 +83,11 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
|
|||
# Work on a dict copy so we don't mutate original models
|
||||
data = service_cfg.model_dump()
|
||||
if "api_key" in data and data["api_key"]:
|
||||
data["api_key"] = mask_key(data["api_key"])
|
||||
raw = data["api_key"]
|
||||
if isinstance(raw, list):
|
||||
data["api_key"] = [mask_key(k) for k in raw]
|
||||
else:
|
||||
data["api_key"] = mask_key(raw)
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ stored, while honouring masked API keys.
|
|||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of
|
||||
from api.services.configuration.masking import resolve_masked_api_keys
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
|
||||
|
||||
|
|
@ -50,12 +50,10 @@ def merge_user_configurations(
|
|||
if not provider_changed:
|
||||
# conditional preservation of api_key
|
||||
if incoming_api_key is not None:
|
||||
if (
|
||||
old_cfg
|
||||
and "api_key" in old_cfg
|
||||
and is_mask_of(incoming_api_key, old_cfg["api_key"])
|
||||
):
|
||||
incoming_cfg["api_key"] = old_cfg["api_key"]
|
||||
if old_cfg and "api_key" in old_cfg:
|
||||
incoming_cfg["api_key"] = resolve_masked_api_keys(
|
||||
incoming_api_key, old_cfg["api_key"]
|
||||
)
|
||||
else:
|
||||
if "api_key" in old_cfg:
|
||||
incoming_cfg["api_key"] = old_cfg["api_key"]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import random
|
||||
from enum import Enum, auto
|
||||
from typing import Annotated, Dict, Literal, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
|
||||
class ServiceType(Enum):
|
||||
|
|
@ -38,7 +40,31 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.DOGRAH,
|
||||
# ServiceProviders.SARVAM,
|
||||
]
|
||||
api_key: str
|
||||
api_key: str | list[str]
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def validate_api_key(cls, v):
|
||||
if isinstance(v, list) and len(v) == 0:
|
||||
raise ValueError("api_key list must not be empty")
|
||||
return v
|
||||
|
||||
def __getattribute__(self, name: str):
|
||||
if name == "api_key":
|
||||
value = super().__getattribute__(name)
|
||||
if isinstance(value, list):
|
||||
selected_api_key = random.choice(value)
|
||||
logger.debug(f"selected API key {selected_api_key[:-4]}")
|
||||
return selected_api_key
|
||||
return value
|
||||
return super().__getattribute__(name)
|
||||
|
||||
def get_all_api_keys(self) -> list[str]:
|
||||
"""Get all API keys as a list (bypasses random selection)."""
|
||||
value = super().__getattribute__("api_key")
|
||||
if isinstance(value, list):
|
||||
return list(value)
|
||||
return [value]
|
||||
|
||||
|
||||
class BaseLLMConfiguration(BaseServiceConfiguration):
|
||||
|
|
@ -150,7 +176,6 @@ DOGRAH_LLM_MODELS = ["default", "accurate", "fast", "lite", "zen"]
|
|||
class OpenAILLMService(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(default="gpt-4.1", json_schema_extra={"examples": OPENAI_MODELS})
|
||||
api_key: str
|
||||
|
||||
|
||||
@register_llm
|
||||
|
|
@ -159,7 +184,6 @@ class GoogleLLMService(BaseLLMConfiguration):
|
|||
model: str = Field(
|
||||
default="gemini-2.0-flash", json_schema_extra={"examples": GOOGLE_MODELS}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
@register_llm
|
||||
|
|
@ -168,7 +192,6 @@ class GroqLLMService(BaseLLMConfiguration):
|
|||
model: str = Field(
|
||||
default="llama-3.3-70b-versatile", json_schema_extra={"examples": GROQ_MODELS}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
@register_llm
|
||||
|
|
@ -177,7 +200,7 @@ class OpenRouterLLMConfiguration(BaseLLMConfiguration):
|
|||
model: str = Field(
|
||||
default="openai/gpt-4.1", json_schema_extra={"examples": OPENROUTER_MODELS}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
base_url: str = Field(default="https://openrouter.ai/api/v1")
|
||||
|
||||
|
||||
|
|
@ -187,7 +210,7 @@ class AzureLLMService(BaseLLMConfiguration):
|
|||
model: str = Field(
|
||||
default="gpt-4.1-mini", json_schema_extra={"examples": AZURE_MODELS}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
endpoint: str
|
||||
|
||||
|
||||
|
|
@ -197,7 +220,6 @@ class DograhLLMService(BaseLLMConfiguration):
|
|||
model: str = Field(
|
||||
default="default", json_schema_extra={"examples": DOGRAH_LLM_MODELS}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
LLMConfig = Annotated[
|
||||
|
|
@ -219,7 +241,6 @@ LLMConfig = Annotated[
|
|||
class DeepgramTTSConfiguration(BaseServiceConfiguration):
|
||||
provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
|
||||
voice: str = "aura-2-helena-en"
|
||||
api_key: str
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
|
@ -247,7 +268,6 @@ class ElevenlabsTTSConfiguration(BaseServiceConfiguration):
|
|||
default="eleven_flash_v2_5",
|
||||
json_schema_extra={"examples": ELEVENLABS_TTS_MODELS},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
OPENAI_TTS_MODELS = ["gpt-4o-mini-tts"]
|
||||
|
|
@ -260,7 +280,6 @@ class OpenAITTSService(BaseTTSConfiguration):
|
|||
default="gpt-4o-mini-tts", json_schema_extra={"examples": OPENAI_TTS_MODELS}
|
||||
)
|
||||
voice: str = "alloy"
|
||||
api_key: str
|
||||
|
||||
|
||||
DOGRAH_TTS_MODELS = ["default"]
|
||||
|
|
@ -274,7 +293,6 @@ class DograhTTSService(BaseTTSConfiguration):
|
|||
)
|
||||
voice: str = "default"
|
||||
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice")
|
||||
api_key: str
|
||||
|
||||
|
||||
CARTESIA_TTS_MODELS = ["sonic-3"]
|
||||
|
|
@ -287,7 +305,6 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
|
|||
default="sonic-3", json_schema_extra={"examples": CARTESIA_TTS_MODELS}
|
||||
)
|
||||
voice: str = Field(default="3faa81ae-d3d8-4ab1-9e44-e50e46d33c30")
|
||||
api_key: str
|
||||
|
||||
|
||||
SARVAM_TTS_MODELS = ["bulbul:v2", "bulbul:v3"]
|
||||
|
|
@ -376,7 +393,6 @@ class SarvamTTSConfiguration(BaseTTSConfiguration):
|
|||
language: str = Field(
|
||||
default="hi-IN", json_schema_extra={"examples": SARVAM_LANGUAGES}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
TTSConfig = Annotated[
|
||||
|
|
@ -496,7 +512,6 @@ class DeepgramSTTConfiguration(BaseSTTConfiguration):
|
|||
},
|
||||
},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
CARTESIA_STT_MODELS = ["ink-whisper"]
|
||||
|
|
@ -508,7 +523,6 @@ class CartesiaSTTConfiguration(BaseSTTConfiguration):
|
|||
model: str = Field(
|
||||
default="ink-whisper", json_schema_extra={"examples": CARTESIA_STT_MODELS}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
OPENAI_STT_MODELS = ["gpt-4o-transcribe"]
|
||||
|
|
@ -520,7 +534,6 @@ class OpenAISTTConfiguration(BaseSTTConfiguration):
|
|||
model: str = Field(
|
||||
default="gpt-4o-transcribe", json_schema_extra={"examples": OPENAI_STT_MODELS}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
# Dograh STT Service
|
||||
|
|
@ -537,7 +550,6 @@ class DograhSTTService(BaseSTTConfiguration):
|
|||
language: str = Field(
|
||||
default="multi", json_schema_extra={"examples": DOGRAH_STT_LANGUAGES}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
# Sarvam STT Service
|
||||
|
|
@ -553,7 +565,6 @@ class SarvamSTTConfiguration(BaseSTTConfiguration):
|
|||
language: str = Field(
|
||||
default="hi-IN", json_schema_extra={"examples": SARVAM_LANGUAGES}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
# Speechmatics STT Service
|
||||
|
|
@ -593,7 +604,6 @@ class SpeechmaticsSTTConfiguration(BaseSTTConfiguration):
|
|||
language: str = Field(
|
||||
default="en", json_schema_extra={"examples": SPEECHMATICS_STT_LANGUAGES}
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
STTConfig = Annotated[
|
||||
|
|
@ -619,7 +629,6 @@ class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
default="text-embedding-3-small",
|
||||
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
OPENROUTER_EMBEDDING_MODELS = ["openai/text-embedding-3-small"]
|
||||
|
|
@ -632,7 +641,7 @@ class OpenRouterEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
default="openai/text-embedding-3-small",
|
||||
json_schema_extra={"examples": OPENROUTER_EMBEDDING_MODELS},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
base_url: str = Field(default="https://openrouter.ai/api/v1")
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue