from __future__ import annotations """Utilities for masking API keys before they are sent to the client. The rules are simple: 1. Only expose the last *visible* characters (default 4) of a key. 2. Incoming masked keys are considered a placeholder – if they equal the mask of the already-stored key, we treat them as *unchanged* and keep the real value in storage. """ from typing import Any, Dict, Optional from api.schemas.user_configuration import UserConfiguration from api.services.configuration.registry import ServiceConfig VISIBLE_CHARS = 4 # number of trailing characters to reveal MASK_CHAR = "*" def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str: """Return a masked representation of *real_key*. Example: >>> mask_key("sk-1234567890abcdef") '****************cdef' """ if real_key is None: return "" if visible <= 0 or visible >= len(real_key): # mask entire key or nothing to mask – edge-cases return MASK_CHAR * len(real_key) masked_part = MASK_CHAR * (len(real_key) - visible) return f"{masked_part}{real_key[-visible:]}" def is_mask_of(masked: str, real_key: str) -> bool: """Return *True* if *masked* equals the mask of *real_key* under the current rules.""" return mask_key(real_key) == masked # --------------------------------------------------------------------------- # High-level helpers for UserConfiguration objects # --------------------------------------------------------------------------- def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, Any]]: if service_cfg is None: return None # Work on a dict copy so we don't mutate original models data = service_cfg.model_dump() if "api_key" in data and data["api_key"]: data["api_key"] = mask_key(data["api_key"]) return data def mask_user_config(config: UserConfiguration) -> Dict[str, Any]: """Return a JSON-serialisable dict of *config* with every api_key masked.""" return { "llm": _mask_service(config.llm), "tts": _mask_service(config.tts), "stt": _mask_service(config.stt), "embeddings": _mask_service(config.embeddings), "test_phone_number": config.test_phone_number, "timezone": config.timezone, } # --------------------------------------------------------------------------- # Workflow definition helpers – mask / merge QA-node API keys # --------------------------------------------------------------------------- _QA_API_KEY_FIELD = "qa_api_key" def mask_workflow_definition(workflow_definition: Optional[Dict]) -> Optional[Dict]: """Return a *shallow copy* of *workflow_definition* with QA-node API keys masked.""" if not workflow_definition: return workflow_definition import copy masked = copy.deepcopy(workflow_definition) for node in masked.get("nodes", []): if node.get("type") != "qa": continue data = node.get("data", {}) raw_key = data.get(_QA_API_KEY_FIELD) if raw_key: data[_QA_API_KEY_FIELD] = mask_key(raw_key) return masked def merge_workflow_api_keys( incoming_definition: Optional[Dict], existing_definition: Optional[Dict] ) -> Optional[Dict]: """Preserve real QA-node API keys when the incoming value is a masked placeholder. For each QA node in *incoming_definition*, if its ``qa_api_key`` equals the masked form of the corresponding node in *existing_definition*, the real key is restored so it is never lost. """ if not incoming_definition or not existing_definition: return incoming_definition # Build lookup: node-id → data for existing QA nodes existing_qa: Dict[str, Dict] = {} for node in existing_definition.get("nodes", []): if node.get("type") == "qa": existing_qa[node["id"]] = node.get("data", {}) for node in incoming_definition.get("nodes", []): if node.get("type") != "qa": continue data = node.get("data", {}) incoming_key = data.get(_QA_API_KEY_FIELD) if not incoming_key: continue old_data = existing_qa.get(node["id"]) if not old_data: continue old_key = old_data.get(_QA_API_KEY_FIELD, "") if old_key and is_mask_of(incoming_key, old_key): data[_QA_API_KEY_FIELD] = old_key return incoming_definition