feat: stamp API key into model override at save time to survive global provider change (#362)

* fix: stamp API key into model override at save time to survive global provider change

When a workflow overrides the TTS/LLM/STT provider to match the current
global config, the override dict only stores model/voice fields, not the
API key. If the global config later switches to a different provider, the
override can no longer inherit the API key and calls fail.

Fix: enrich_overrides_with_api_keys() copies the global provider's API
key (and other secret fields) into the override dict at workflow-save
time, making the override self-contained regardless of future global
config changes.

* feat: add test coverage and masking logic

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
nuthalapativarun 2026-05-27 01:31:14 -07:00 committed by GitHub
parent 8a58b0992d
commit 5b61ad645f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 451 additions and 39 deletions

View file

@ -9,6 +9,7 @@ The rules are simple:
in storage.
"""
import copy
from typing import Any, Dict, Optional
from api.schemas.user_configuration import UserConfiguration
@ -19,6 +20,7 @@ VISIBLE_CHARS = 4 # number of trailing characters to reveal
MASK_CHAR = "*"
MASK_MARKER = "***" # substring that indicates a masked key
SERVICE_SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
MODEL_OVERRIDE_FIELDS = ("llm", "tts", "stt", "realtime")
def contains_masked_key(value: str | list[str] | None) -> bool:
@ -67,6 +69,12 @@ def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str:
return f"{masked_part}{real_key[-visible:]}"
def _mask_secret_value(value: str | list[str]) -> str | list[str]:
if isinstance(value, list):
return [mask_key(k) for k in value]
return mask_key(value)
def is_mask_of(masked: str, real_key: str) -> bool:
"""Return *True* if *masked* equals the mask of *real_key* under the current rules."""
return mask_key(real_key) == masked
@ -117,10 +125,7 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
if secret_field not in data or not data[secret_field]:
continue
raw = data[secret_field]
if isinstance(raw, list):
data[secret_field] = [mask_key(k) for k in raw]
else:
data[secret_field] = mask_key(raw)
data[secret_field] = _mask_secret_value(raw)
return data
@ -139,6 +144,28 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
}
def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
"""Mask secret fields inside workflow-level model overrides for API responses."""
if not config:
return config
masked = copy.deepcopy(config)
model_overrides = masked.get("model_overrides")
if not isinstance(model_overrides, dict):
return masked
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
return masked
# ---------------------------------------------------------------------------
# Workflow definition helpers mask / merge node API keys
# ---------------------------------------------------------------------------

View file

@ -4,17 +4,67 @@ from __future__ import annotations
stored, while honouring masked API keys.
"""
import copy
from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import (
MODEL_OVERRIDE_FIELDS,
SERVICE_SECRET_FIELDS,
contains_masked_key,
resolve_masked_api_keys,
)
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings", "realtime")
def _same_provider(incoming_cfg: dict, existing_cfg: dict) -> bool:
return not (
existing_cfg.get("provider") is not None
and incoming_cfg.get("provider") is not None
and incoming_cfg.get("provider") != existing_cfg.get("provider")
)
def _merge_service_secret_fields(
incoming_cfg: dict,
existing_cfg: dict,
*,
preserve_missing: bool,
masked_value_preserves_full_secret: bool = False,
) -> dict:
"""Restore existing real secrets when incoming values are masked.
If ``preserve_missing`` is true, missing incoming secret fields are also
copied from the existing config. User config updates need that behavior;
workflow model overrides leave missing secrets blank so later enrichment can
copy from the current global config.
"""
if not _same_provider(incoming_cfg, existing_cfg):
return incoming_cfg
for secret_field in SERVICE_SECRET_FIELDS:
if secret_field not in existing_cfg:
continue
incoming_secret = incoming_cfg.get(secret_field)
existing_secret = existing_cfg[secret_field]
if incoming_secret is not None:
if contains_masked_key(incoming_secret):
incoming_cfg[secret_field] = (
existing_secret
if masked_value_preserves_full_secret
else resolve_masked_api_keys(
incoming_secret,
existing_secret,
)
)
elif preserve_missing:
incoming_cfg[secret_field] = existing_secret
return incoming_cfg
def merge_user_configurations(
existing: UserConfiguration, incoming_partial: Dict[str, dict]
) -> UserConfiguration:
@ -41,23 +91,12 @@ def merge_user_configurations(
return # nothing to do
old_cfg = merged.get(service_name, {})
provider_changed = (
old_cfg.get("provider") is not None
and incoming_cfg.get("provider") is not None
and incoming_cfg.get("provider") != old_cfg.get("provider")
)
if not provider_changed:
for secret_field in SERVICE_SECRET_FIELDS:
incoming_secret = incoming_cfg.get(secret_field)
if incoming_secret is not None:
if old_cfg and secret_field in old_cfg:
incoming_cfg[secret_field] = resolve_masked_api_keys(
incoming_secret, old_cfg[secret_field]
)
elif secret_field in old_cfg:
incoming_cfg[secret_field] = old_cfg[secret_field]
if old_cfg:
incoming_cfg = _merge_service_secret_fields(
incoming_cfg,
old_cfg,
preserve_missing=True,
)
merged[service_name] = incoming_cfg
@ -75,3 +114,46 @@ def merge_user_configurations(
merged["timezone"] = incoming_partial["timezone"]
return UserConfiguration.model_validate(merged)
def merge_workflow_configuration_secrets(
incoming_config: dict | None,
existing_config: dict | None,
) -> dict | None:
"""Restore persisted workflow override secrets when the client sends masks.
Workflow model overrides intentionally persist real keys so a workflow keeps
running after the global provider changes. API responses mask those keys, so
save requests must merge masked placeholders back to the stored real values.
Unlike user config updates, a missing workflow override secret is not copied
from the existing workflow config. Missing means "copy from current global"
during the later enrichment step.
"""
if not incoming_config or not existing_config:
return incoming_config
merged = copy.deepcopy(incoming_config)
incoming_overrides = merged.get("model_overrides")
existing_overrides = existing_config.get("model_overrides")
if not isinstance(incoming_overrides, dict) or not isinstance(
existing_overrides, dict
):
return merged
for section in MODEL_OVERRIDE_FIELDS:
incoming_section = incoming_overrides.get(section)
existing_section = existing_overrides.get(section)
if not isinstance(incoming_section, dict) or not isinstance(
existing_section, dict
):
continue
incoming_overrides[section] = _merge_service_secret_fields(
incoming_section,
existing_section,
preserve_missing=False,
masked_value_preserves_full_secret=True,
)
return merged

View file

@ -2,6 +2,8 @@
from __future__ import annotations
import copy
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.registry import (
REGISTRY,
@ -29,6 +31,48 @@ def _build_section_from_override(service_type: ServiceType, override: dict):
return config_cls(**override)
_SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
def enrich_overrides_with_api_keys(
model_overrides: dict,
user_config: UserConfiguration,
) -> dict:
"""Copy API keys from the global config into model_overrides where missing.
When a workflow override selects the same provider as the current global
config but omits the API key, the override becomes broken if the global
config later switches to a different provider. This function stamps the
global provider's API key (and other secret fields) into the override at
save time so the override is self-contained.
"""
result = copy.deepcopy(model_overrides)
for section_key in _SECTION_MAP:
if section_key not in result:
continue
override = result[section_key]
override_provider = override.get("provider")
if not override_provider:
continue
global_section = getattr(user_config, section_key, None)
if global_section is None:
continue
if getattr(global_section, "provider", None) != override_provider:
continue
for field in _SECRET_FIELDS:
if override.get(field):
continue
if field == "api_key" and hasattr(global_section, "get_all_api_keys"):
all_keys = global_section.get_all_api_keys()
if all_keys:
override[field] = all_keys[0] if len(all_keys) == 1 else all_keys
else:
global_value = getattr(global_section, field, None)
if global_value is not None:
override[field] = global_value
return result
def resolve_effective_config(
user_config: UserConfiguration,
model_overrides: dict | None,