feat: allow multiple API keys

This commit is contained in:
Abhishek Kumar 2026-03-09 22:34:27 +05:30
parent 162bfabac3
commit 226b4cff91
8 changed files with 172 additions and 133 deletions

View file

@ -71,10 +71,10 @@ async def get_auth_user(
class UserConfigurationRequestResponseSchema(BaseModel):
llm: dict[str, Union[str, float]] | None = None
tts: dict[str, Union[str, float]] | None = None
stt: dict[str, Union[str, float]] | None = None
embeddings: dict[str, Union[str, float]] | None = None
llm: dict[str, Union[str, float, list[str]]] | None = None
tts: dict[str, Union[str, float, list[str]]] | None = None
stt: dict[str, Union[str, float, list[str]]] | None = None
embeddings: dict[str, Union[str, float, list[str]]] | None = None
test_phone_number: str | None = None
timezone: str | None = None
organization_pricing: dict[str, Union[float, str, bool]] | None = None

View file

@ -41,6 +41,36 @@ def is_mask_of(masked: str, real_key: str) -> bool:
return mask_key(real_key) == masked
def resolve_masked_api_keys(
incoming: str | list[str], existing: str | list[str]
) -> str | list[str]:
"""Resolve masked API keys against existing real keys.
For each incoming key, if it matches the mask of an existing key, the real
key is restored. New (unmasked) keys are kept as-is. This handles adds,
removes, reorders, and partial replacements correctly.
"""
if isinstance(incoming, str) and isinstance(existing, str):
return existing if is_mask_of(incoming, existing) else incoming
existing_list = existing if isinstance(existing, list) else [existing]
incoming_list = incoming if isinstance(incoming, list) else [incoming]
resolved: list[str] = []
used: set[int] = set()
for key in incoming_list:
matched = False
for i, real in enumerate(existing_list):
if i not in used and is_mask_of(key, real):
resolved.append(real)
used.add(i)
matched = True
break
if not matched:
resolved.append(key)
return resolved
# ---------------------------------------------------------------------------
# High-level helpers for UserConfiguration objects
# ---------------------------------------------------------------------------
@ -53,7 +83,11 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
# Work on a dict copy so we don't mutate original models
data = service_cfg.model_dump()
if "api_key" in data and data["api_key"]:
data["api_key"] = mask_key(data["api_key"])
raw = data["api_key"]
if isinstance(raw, list):
data["api_key"] = [mask_key(k) for k in raw]
else:
data["api_key"] = mask_key(raw)
return data

View file

@ -7,7 +7,7 @@ stored, while honouring masked API keys.
from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import is_mask_of
from api.services.configuration.masking import resolve_masked_api_keys
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
@ -50,12 +50,10 @@ def merge_user_configurations(
if not provider_changed:
# conditional preservation of api_key
if incoming_api_key is not None:
if (
old_cfg
and "api_key" in old_cfg
and is_mask_of(incoming_api_key, old_cfg["api_key"])
):
incoming_cfg["api_key"] = old_cfg["api_key"]
if old_cfg and "api_key" in old_cfg:
incoming_cfg["api_key"] = resolve_masked_api_keys(
incoming_api_key, old_cfg["api_key"]
)
else:
if "api_key" in old_cfg:
incoming_cfg["api_key"] = old_cfg["api_key"]

View file

@ -1,7 +1,9 @@
import random
from enum import Enum, auto
from typing import Annotated, Dict, Literal, Type, TypeVar, Union
from pydantic import BaseModel, Field, computed_field
from loguru import logger
from pydantic import BaseModel, Field, computed_field, field_validator
class ServiceType(Enum):
@ -38,7 +40,31 @@ class BaseServiceConfiguration(BaseModel):
ServiceProviders.DOGRAH,
# ServiceProviders.SARVAM,
]
api_key: str
api_key: str | list[str]
@field_validator("api_key")
@classmethod
def validate_api_key(cls, v):
if isinstance(v, list) and len(v) == 0:
raise ValueError("api_key list must not be empty")
return v
def __getattribute__(self, name: str):
if name == "api_key":
value = super().__getattribute__(name)
if isinstance(value, list):
selected_api_key = random.choice(value)
logger.debug(f"selected API key {selected_api_key[:-4]}")
return selected_api_key
return value
return super().__getattribute__(name)
def get_all_api_keys(self) -> list[str]:
"""Get all API keys as a list (bypasses random selection)."""
value = super().__getattribute__("api_key")
if isinstance(value, list):
return list(value)
return [value]
class BaseLLMConfiguration(BaseServiceConfiguration):
@ -150,7 +176,6 @@ DOGRAH_LLM_MODELS = ["default", "accurate", "fast", "lite", "zen"]
class OpenAILLMService(BaseLLMConfiguration):
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
model: str = Field(default="gpt-4.1", json_schema_extra={"examples": OPENAI_MODELS})
api_key: str
@register_llm
@ -159,7 +184,6 @@ class GoogleLLMService(BaseLLMConfiguration):
model: str = Field(
default="gemini-2.0-flash", json_schema_extra={"examples": GOOGLE_MODELS}
)
api_key: str
@register_llm
@ -168,7 +192,6 @@ class GroqLLMService(BaseLLMConfiguration):
model: str = Field(
default="llama-3.3-70b-versatile", json_schema_extra={"examples": GROQ_MODELS}
)
api_key: str
@register_llm
@ -177,7 +200,7 @@ class OpenRouterLLMConfiguration(BaseLLMConfiguration):
model: str = Field(
default="openai/gpt-4.1", json_schema_extra={"examples": OPENROUTER_MODELS}
)
api_key: str
base_url: str = Field(default="https://openrouter.ai/api/v1")
@ -187,7 +210,7 @@ class AzureLLMService(BaseLLMConfiguration):
model: str = Field(
default="gpt-4.1-mini", json_schema_extra={"examples": AZURE_MODELS}
)
api_key: str
endpoint: str
@ -197,7 +220,6 @@ class DograhLLMService(BaseLLMConfiguration):
model: str = Field(
default="default", json_schema_extra={"examples": DOGRAH_LLM_MODELS}
)
api_key: str
LLMConfig = Annotated[
@ -219,7 +241,6 @@ LLMConfig = Annotated[
class DeepgramTTSConfiguration(BaseServiceConfiguration):
provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
voice: str = "aura-2-helena-en"
api_key: str
@computed_field
@property
@ -247,7 +268,6 @@ class ElevenlabsTTSConfiguration(BaseServiceConfiguration):
default="eleven_flash_v2_5",
json_schema_extra={"examples": ELEVENLABS_TTS_MODELS},
)
api_key: str
OPENAI_TTS_MODELS = ["gpt-4o-mini-tts"]
@ -260,7 +280,6 @@ class OpenAITTSService(BaseTTSConfiguration):
default="gpt-4o-mini-tts", json_schema_extra={"examples": OPENAI_TTS_MODELS}
)
voice: str = "alloy"
api_key: str
DOGRAH_TTS_MODELS = ["default"]
@ -274,7 +293,6 @@ class DograhTTSService(BaseTTSConfiguration):
)
voice: str = "default"
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice")
api_key: str
CARTESIA_TTS_MODELS = ["sonic-3"]
@ -287,7 +305,6 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
default="sonic-3", json_schema_extra={"examples": CARTESIA_TTS_MODELS}
)
voice: str = Field(default="3faa81ae-d3d8-4ab1-9e44-e50e46d33c30")
api_key: str
SARVAM_TTS_MODELS = ["bulbul:v2", "bulbul:v3"]
@ -376,7 +393,6 @@ class SarvamTTSConfiguration(BaseTTSConfiguration):
language: str = Field(
default="hi-IN", json_schema_extra={"examples": SARVAM_LANGUAGES}
)
api_key: str
TTSConfig = Annotated[
@ -496,7 +512,6 @@ class DeepgramSTTConfiguration(BaseSTTConfiguration):
},
},
)
api_key: str
CARTESIA_STT_MODELS = ["ink-whisper"]
@ -508,7 +523,6 @@ class CartesiaSTTConfiguration(BaseSTTConfiguration):
model: str = Field(
default="ink-whisper", json_schema_extra={"examples": CARTESIA_STT_MODELS}
)
api_key: str
OPENAI_STT_MODELS = ["gpt-4o-transcribe"]
@ -520,7 +534,6 @@ class OpenAISTTConfiguration(BaseSTTConfiguration):
model: str = Field(
default="gpt-4o-transcribe", json_schema_extra={"examples": OPENAI_STT_MODELS}
)
api_key: str
# Dograh STT Service
@ -537,7 +550,6 @@ class DograhSTTService(BaseSTTConfiguration):
language: str = Field(
default="multi", json_schema_extra={"examples": DOGRAH_STT_LANGUAGES}
)
api_key: str
# Sarvam STT Service
@ -553,7 +565,6 @@ class SarvamSTTConfiguration(BaseSTTConfiguration):
language: str = Field(
default="hi-IN", json_schema_extra={"examples": SARVAM_LANGUAGES}
)
api_key: str
# Speechmatics STT Service
@ -593,7 +604,6 @@ class SpeechmaticsSTTConfiguration(BaseSTTConfiguration):
language: str = Field(
default="en", json_schema_extra={"examples": SPEECHMATICS_STT_LANGUAGES}
)
api_key: str
STTConfig = Annotated[
@ -619,7 +629,6 @@ class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
default="text-embedding-3-small",
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
)
api_key: str
OPENROUTER_EMBEDDING_MODELS = ["openai/text-embedding-3-small"]
@ -632,7 +641,7 @@ class OpenRouterEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
default="openai/text-embedding-3-small",
json_schema_extra={"examples": OPENROUTER_EMBEDDING_MODELS},
)
api_key: str
base_url: str = Field(default="https://openrouter.ai/api/v1")