dograh/api/services/configuration/masking.py

133 lines
4.4 KiB
Python
Raw Permalink Normal View History

2025-09-09 14:37:32 +05:30
from __future__ import annotations
"""Utilities for masking API keys before they are sent to the client.
The rules are simple:
1. Only expose the last *visible* characters (default 4) of a key.
2. Incoming masked keys are considered a placeholder: if they equal the mask of
the already-stored key, we treat them as *unchanged* and keep the real value
in storage.
"""
from typing import Any, Dict, Optional
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.registry import ServiceConfig
# Default number of trailing key characters that stay readable after masking.
VISIBLE_CHARS = 4 # number of trailing characters to reveal
# Character written in place of every hidden position of a key.
MASK_CHAR = "*"
def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str:
    """Mask *real_key* so that only its trailing characters stay readable.

    Args:
        real_key: The secret to mask. ``None`` is tolerated and yields ``""``.
        visible: Number of trailing characters to leave unmasked.

    Returns:
        A string as long as *real_key* in which every character except the
        last *visible* ones is replaced by ``MASK_CHAR``. When *visible* is
        not positive, or is at least the key length, every character is
        replaced.

    Example:
        >>> mask_key("abcdefgh")
        '****efgh'
    """
    if real_key is None:
        return ""

    key_len = len(real_key)
    if 0 < visible < key_len:
        hidden_len = key_len - visible
        return f"{MASK_CHAR * hidden_len}{real_key[-visible:]}"

    # Either nothing may be revealed, or the requested suffix would cover the
    # whole key: hide every character.
    return MASK_CHAR * key_len
def is_mask_of(masked: str, real_key: str) -> bool:
    """Tell whether *masked* is exactly what ``mask_key`` yields for *real_key*.

    The comparison uses ``mask_key`` with its default number of visible
    characters.
    """
    expected_mask = mask_key(real_key)
    return expected_mask == masked
# ---------------------------------------------------------------------------
# High-level helpers for UserConfiguration objects
# ---------------------------------------------------------------------------
def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, Any]]:
    """Dump one service configuration to a dict and mask its ``api_key``.

    Args:
        service_cfg: The service configuration model, or ``None``.

    Returns:
        ``None`` when *service_cfg* is ``None``; otherwise the result of
        ``service_cfg.model_dump()`` with a non-empty ``"api_key"`` entry
        replaced by its masked form.
    """
    if service_cfg is None:
        return None

    # Operate on the dumped dict rather than on the model itself so the
    # caller's object keeps its real key.
    dumped = service_cfg.model_dump()
    secret = dumped.get("api_key")
    if secret:
        dumped["api_key"] = mask_key(secret)
    return dumped
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
    """Return a dict view of *config* with every service ``api_key`` masked.

    The ``llm``, ``tts``, ``stt`` and ``embeddings`` sections are each passed
    through ``_mask_service`` (so a missing section stays ``None``), while
    ``test_phone_number`` and ``timezone`` are copied over unchanged.

    Args:
        config: The user configuration to prepare for sending to the client.

    Returns:
        A new dict, intended to be JSON-serialisable, keyed by section name.
    """
    return {
        "llm": _mask_service(config.llm),
        "tts": _mask_service(config.tts),
        "stt": _mask_service(config.stt),
        "embeddings": _mask_service(config.embeddings),
        "test_phone_number": config.test_phone_number,
        "timezone": config.timezone,
    }
# ---------------------------------------------------------------------------
# Workflow definition helpers: mask / merge QA-node API keys
# ---------------------------------------------------------------------------
# Key inside a QA node's ``data`` dict under which that node's API key is stored.
_QA_API_KEY_FIELD = "qa_api_key"
def mask_workflow_definition(workflow_definition: Optional[Dict]) -> Optional[Dict]:
    """Return a *deep copy* of *workflow_definition* with QA-node API keys masked.

    Only nodes whose ``"type"`` is ``"qa"`` are touched; for those, a non-empty
    ``qa_api_key`` entry in the node's ``"data"`` dict is replaced by its
    masked form. The input is never mutated.

    Args:
        workflow_definition: The workflow definition dict, or ``None``.

    Returns:
        *workflow_definition* itself when it is falsy (``None`` or empty),
        otherwise an independent masked copy.
    """
    if not workflow_definition:
        return workflow_definition

    import copy

    masked = copy.deepcopy(workflow_definition)
    # ``or []`` tolerates an explicit ``"nodes": None`` in the payload.
    for node in masked.get("nodes") or []:
        if not isinstance(node, dict) or node.get("type") != "qa":
            continue
        data = node.get("data")
        if not isinstance(data, dict):
            # Missing or null ``data``: there is no key to mask.
            continue
        raw_key = data.get(_QA_API_KEY_FIELD)
        if raw_key:
            data[_QA_API_KEY_FIELD] = mask_key(raw_key)
    return masked
def merge_workflow_api_keys(
    incoming_definition: Optional[Dict], existing_definition: Optional[Dict]
) -> Optional[Dict]:
    """Preserve real QA-node API keys when the incoming value is a masked placeholder.

    For each QA node in *incoming_definition*, if its ``qa_api_key`` equals
    the masked form of the key stored on the node with the same ``"id"`` in
    *existing_definition*, the real key is written back so it is never lost.
    Any other incoming value (a new key, an empty value) is left as sent.

    Args:
        incoming_definition: The definition received from the client, or ``None``.
        existing_definition: The currently stored definition, or ``None``.

    Returns:
        *incoming_definition* itself. It is modified in place, not copied.
        When either argument is falsy it is returned untouched.
    """
    if not incoming_definition or not existing_definition:
        return incoming_definition

    # Build lookup: node-id → data for existing QA nodes. Nodes without an
    # "id" or without a data dict cannot be matched and are skipped.
    existing_qa: Dict[str, Dict] = {}
    for node in existing_definition.get("nodes") or []:
        if not isinstance(node, dict) or node.get("type") != "qa":
            continue
        old_node_data = node.get("data")
        if "id" in node and isinstance(old_node_data, dict):
            existing_qa[node["id"]] = old_node_data

    for node in incoming_definition.get("nodes") or []:
        if not isinstance(node, dict) or node.get("type") != "qa":
            continue
        data = node.get("data")
        if not isinstance(data, dict) or "id" not in node:
            continue
        incoming_key = data.get(_QA_API_KEY_FIELD)
        if not incoming_key:
            continue
        old_data = existing_qa.get(node["id"])
        if not old_data:
            continue
        old_key = old_data.get(_QA_API_KEY_FIELD, "")
        # Restore the real key only when the client echoed back its mask.
        if old_key and is_mask_of(incoming_key, old_key):
            data[_QA_API_KEY_FIELD] = old_key
    return incoming_definition