mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
feat: stamp API key into model override at save time to survive global provider change (#362)
* fix: stamp API key into model override at save time to survive global provider change When a workflow overrides the TTS/LLM/STT provider to match the current global config, the override dict only stores model/voice fields, not the API key. If the global config later switches to a different provider, the override can no longer inherit the API key and calls fail. Fix: enrich_overrides_with_api_keys() copies the global provider's API key (and other secret fields) into the override dict at workflow-save time, making the override self-contained regardless of future global config changes. * feat: add test coverage and masking logic --------- Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
8a58b0992d
commit
5b61ad645f
6 changed files with 451 additions and 39 deletions
|
|
@ -9,6 +9,7 @@ The rules are simple:
|
|||
in storage.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
|
|
@ -19,6 +20,7 @@ VISIBLE_CHARS = 4 # number of trailing characters to reveal
|
|||
MASK_CHAR = "*"
|
||||
MASK_MARKER = "***" # substring that indicates a masked key
|
||||
SERVICE_SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
|
||||
MODEL_OVERRIDE_FIELDS = ("llm", "tts", "stt", "realtime")
|
||||
|
||||
|
||||
def contains_masked_key(value: str | list[str] | None) -> bool:
|
||||
|
|
@ -67,6 +69,12 @@ def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str:
|
|||
return f"{masked_part}{real_key[-visible:]}"
|
||||
|
||||
|
||||
def _mask_secret_value(value: str | list[str]) -> str | list[str]:
|
||||
if isinstance(value, list):
|
||||
return [mask_key(k) for k in value]
|
||||
return mask_key(value)
|
||||
|
||||
|
||||
def is_mask_of(masked: str, real_key: str) -> bool:
|
||||
"""Return *True* if *masked* equals the mask of *real_key* under the current rules."""
|
||||
return mask_key(real_key) == masked
|
||||
|
|
@ -117,10 +125,7 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
|
|||
if secret_field not in data or not data[secret_field]:
|
||||
continue
|
||||
raw = data[secret_field]
|
||||
if isinstance(raw, list):
|
||||
data[secret_field] = [mask_key(k) for k in raw]
|
||||
else:
|
||||
data[secret_field] = mask_key(raw)
|
||||
data[secret_field] = _mask_secret_value(raw)
|
||||
return data
|
||||
|
||||
|
||||
|
|
@ -139,6 +144,28 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
|
||||
"""Mask secret fields inside workflow-level model overrides for API responses."""
|
||||
if not config:
|
||||
return config
|
||||
|
||||
masked = copy.deepcopy(config)
|
||||
model_overrides = masked.get("model_overrides")
|
||||
if not isinstance(model_overrides, dict):
|
||||
return masked
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
|
||||
return masked
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -4,17 +4,67 @@ from __future__ import annotations
|
|||
stored, while honouring masked API keys.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
MODEL_OVERRIDE_FIELDS,
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings", "realtime")
|
||||
|
||||
|
||||
def _same_provider(incoming_cfg: dict, existing_cfg: dict) -> bool:
|
||||
return not (
|
||||
existing_cfg.get("provider") is not None
|
||||
and incoming_cfg.get("provider") is not None
|
||||
and incoming_cfg.get("provider") != existing_cfg.get("provider")
|
||||
)
|
||||
|
||||
|
||||
def _merge_service_secret_fields(
|
||||
incoming_cfg: dict,
|
||||
existing_cfg: dict,
|
||||
*,
|
||||
preserve_missing: bool,
|
||||
masked_value_preserves_full_secret: bool = False,
|
||||
) -> dict:
|
||||
"""Restore existing real secrets when incoming values are masked.
|
||||
|
||||
If ``preserve_missing`` is true, missing incoming secret fields are also
|
||||
copied from the existing config. User config updates need that behavior;
|
||||
workflow model overrides leave missing secrets blank so later enrichment can
|
||||
copy from the current global config.
|
||||
"""
|
||||
if not _same_provider(incoming_cfg, existing_cfg):
|
||||
return incoming_cfg
|
||||
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in existing_cfg:
|
||||
continue
|
||||
|
||||
incoming_secret = incoming_cfg.get(secret_field)
|
||||
existing_secret = existing_cfg[secret_field]
|
||||
if incoming_secret is not None:
|
||||
if contains_masked_key(incoming_secret):
|
||||
incoming_cfg[secret_field] = (
|
||||
existing_secret
|
||||
if masked_value_preserves_full_secret
|
||||
else resolve_masked_api_keys(
|
||||
incoming_secret,
|
||||
existing_secret,
|
||||
)
|
||||
)
|
||||
elif preserve_missing:
|
||||
incoming_cfg[secret_field] = existing_secret
|
||||
|
||||
return incoming_cfg
|
||||
|
||||
|
||||
def merge_user_configurations(
|
||||
existing: UserConfiguration, incoming_partial: Dict[str, dict]
|
||||
) -> UserConfiguration:
|
||||
|
|
@ -41,23 +91,12 @@ def merge_user_configurations(
|
|||
return # nothing to do
|
||||
|
||||
old_cfg = merged.get(service_name, {})
|
||||
|
||||
provider_changed = (
|
||||
old_cfg.get("provider") is not None
|
||||
and incoming_cfg.get("provider") is not None
|
||||
and incoming_cfg.get("provider") != old_cfg.get("provider")
|
||||
)
|
||||
|
||||
if not provider_changed:
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
incoming_secret = incoming_cfg.get(secret_field)
|
||||
if incoming_secret is not None:
|
||||
if old_cfg and secret_field in old_cfg:
|
||||
incoming_cfg[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret, old_cfg[secret_field]
|
||||
)
|
||||
elif secret_field in old_cfg:
|
||||
incoming_cfg[secret_field] = old_cfg[secret_field]
|
||||
if old_cfg:
|
||||
incoming_cfg = _merge_service_secret_fields(
|
||||
incoming_cfg,
|
||||
old_cfg,
|
||||
preserve_missing=True,
|
||||
)
|
||||
|
||||
merged[service_name] = incoming_cfg
|
||||
|
||||
|
|
@ -75,3 +114,46 @@ def merge_user_configurations(
|
|||
merged["timezone"] = incoming_partial["timezone"]
|
||||
|
||||
return UserConfiguration.model_validate(merged)
|
||||
|
||||
|
||||
def merge_workflow_configuration_secrets(
|
||||
incoming_config: dict | None,
|
||||
existing_config: dict | None,
|
||||
) -> dict | None:
|
||||
"""Restore persisted workflow override secrets when the client sends masks.
|
||||
|
||||
Workflow model overrides intentionally persist real keys so a workflow keeps
|
||||
running after the global provider changes. API responses mask those keys, so
|
||||
save requests must merge masked placeholders back to the stored real values.
|
||||
|
||||
Unlike user config updates, a missing workflow override secret is not copied
|
||||
from the existing workflow config. Missing means "copy from current global"
|
||||
during the later enrichment step.
|
||||
"""
|
||||
if not incoming_config or not existing_config:
|
||||
return incoming_config
|
||||
|
||||
merged = copy.deepcopy(incoming_config)
|
||||
incoming_overrides = merged.get("model_overrides")
|
||||
existing_overrides = existing_config.get("model_overrides")
|
||||
if not isinstance(incoming_overrides, dict) or not isinstance(
|
||||
existing_overrides, dict
|
||||
):
|
||||
return merged
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
incoming_section = incoming_overrides.get(section)
|
||||
existing_section = existing_overrides.get(section)
|
||||
if not isinstance(incoming_section, dict) or not isinstance(
|
||||
existing_section, dict
|
||||
):
|
||||
continue
|
||||
|
||||
incoming_overrides[section] = _merge_service_secret_fields(
|
||||
incoming_section,
|
||||
existing_section,
|
||||
preserve_missing=False,
|
||||
masked_value_preserves_full_secret=True,
|
||||
)
|
||||
|
||||
return merged
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
|
|
@ -29,6 +31,48 @@ def _build_section_from_override(service_type: ServiceType, override: dict):
|
|||
return config_cls(**override)
|
||||
|
||||
|
||||
_SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
|
||||
|
||||
|
||||
def enrich_overrides_with_api_keys(
|
||||
model_overrides: dict,
|
||||
user_config: UserConfiguration,
|
||||
) -> dict:
|
||||
"""Copy API keys from the global config into model_overrides where missing.
|
||||
|
||||
When a workflow override selects the same provider as the current global
|
||||
config but omits the API key, the override becomes broken if the global
|
||||
config later switches to a different provider. This function stamps the
|
||||
global provider's API key (and other secret fields) into the override at
|
||||
save time so the override is self-contained.
|
||||
"""
|
||||
result = copy.deepcopy(model_overrides)
|
||||
for section_key in _SECTION_MAP:
|
||||
if section_key not in result:
|
||||
continue
|
||||
override = result[section_key]
|
||||
override_provider = override.get("provider")
|
||||
if not override_provider:
|
||||
continue
|
||||
global_section = getattr(user_config, section_key, None)
|
||||
if global_section is None:
|
||||
continue
|
||||
if getattr(global_section, "provider", None) != override_provider:
|
||||
continue
|
||||
for field in _SECRET_FIELDS:
|
||||
if override.get(field):
|
||||
continue
|
||||
if field == "api_key" and hasattr(global_section, "get_all_api_keys"):
|
||||
all_keys = global_section.get_all_api_keys()
|
||||
if all_keys:
|
||||
override[field] = all_keys[0] if len(all_keys) == 1 else all_keys
|
||||
else:
|
||||
global_value = getattr(global_section, field, None)
|
||||
if global_value is not None:
|
||||
override[field] = global_value
|
||||
return result
|
||||
|
||||
|
||||
def resolve_effective_config(
|
||||
user_config: UserConfiguration,
|
||||
model_overrides: dict | None,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue