From 18606fe3880cd4c53ba912620bdd91ad6e0bccd8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:53 +0530 Subject: [PATCH] feat(automations): add model connection policy support --- .../builtin/agent_task/dependencies.py | 30 +++---- .../actions/builtin/agent_task/invoke.py | 8 +- .../app/automations/actions/types.py | 6 +- .../app/automations/runtime/executor.py | 8 +- .../schemas/definition/envelope.py | 10 +-- .../app/automations/services/automation.py | 12 +-- .../app/automations/services/model_policy.py | 87 +++++++------------ 7 files changed, 67 insertions(+), 94 deletions(-) diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py index 4ef8c52bf..c9584ae2a 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py @@ -39,31 +39,31 @@ async def build_dependencies( *, session: AsyncSession, search_space_id: int, - agent_llm_id: int | None = None, - image_generation_config_id: int | None = None, - vision_llm_config_id: int | None = None, + chat_model_id: int | None = None, + image_gen_model_id: int | None = None, + vision_model_id: int | None = None, ) -> AgentDependencies: """Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer. - Resolves the agent LLM from the automation's *captured* model snapshot - (``agent_llm_id``) so runs are insulated from later chat/search-space model + Resolves the chat model from the automation's *captured* model snapshot + (``chat_model_id``) so runs are insulated from later chat/search-space model changes. The model policy is enforced here as a runtime backstop: a captured model that is no longer billable (e.g. a premium global config was removed) fails the run clearly instead of silently consuming a free model. - When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback), - fall back to the live search space's ``agent_llm_id`` and validate that. + When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback), + fall back to the live search space's ``chat_model_id`` and validate that. """ - if agent_llm_id is not None: + if chat_model_id is not None: try: assert_models_billable( - agent_llm_id=agent_llm_id, - image_generation_config_id=image_generation_config_id, - vision_llm_config_id=vision_llm_config_id, + chat_model_id=chat_model_id, + image_gen_model_id=image_gen_model_id, + vision_model_id=vision_model_id, ) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_agent_llm_id = agent_llm_id or 0 + resolved_chat_model_id = chat_model_id or 0 else: search_space = await session.get(SearchSpace, search_space_id) if search_space is None: @@ -72,15 +72,15 @@ async def build_dependencies( assert_automation_models_billable(search_space) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_agent_llm_id = search_space.agent_llm_id or 0 + resolved_chat_model_id = search_space.chat_model_id or 0 llm, agent_config, err = await load_llm_bundle( session, - config_id=resolved_agent_llm_id, + config_id=resolved_chat_model_id, search_space_id=search_space_id, ) if err is not None or llm is None: - raise DependencyError(err or "failed to load agent LLM config") + raise DependencyError(err or "failed to load chat model config") connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( session, search_space_id=search_space_id diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py index aa96e4f6e..c3a35930d 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -150,9 +150,9 @@ async def run_agent_task( deps = await build_dependencies( session=agent_session, search_space_id=ctx.search_space_id, - agent_llm_id=ctx.agent_llm_id, - image_generation_config_id=ctx.image_generation_config_id, - vision_llm_config_id=ctx.vision_llm_config_id, + chat_model_id=ctx.chat_model_id, + image_gen_model_id=ctx.image_gen_model_id, + vision_model_id=ctx.vision_model_id, ) agent = await create_multi_agent_chat_deep_agent( @@ -167,7 +167,7 @@ async def run_agent_task( firecrawl_api_key=deps.firecrawl_api_key, thread_visibility=ChatVisibility.PRIVATE, mentioned_document_ids=mentioned_document_ids, - image_generation_config_id=ctx.image_generation_config_id, + image_gen_model_id=ctx.image_gen_model_id, ) agent_query, runtime_context = await _resolve_mention_context( diff --git a/surfsense_backend/app/automations/actions/types.py b/surfsense_backend/app/automations/actions/types.py index 453721a43..3ee427512 100644 --- a/surfsense_backend/app/automations/actions/types.py +++ b/surfsense_backend/app/automations/actions/types.py @@ -23,9 +23,9 @@ class ActionContext: # Captured model snapshot from the automation definition (``definition.models``), # resolved per run instead of the live search space. ``None`` falls back to the # search space's current prefs (defensive; should not happen post-capture). - agent_llm_id: int | None = None - image_generation_config_id: int | None = None - vision_llm_config_id: int | None = None + chat_model_id: int | None = None + image_gen_model_id: int | None = None + vision_model_id: int | None = None ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] diff --git a/surfsense_backend/app/automations/runtime/executor.py b/surfsense_backend/app/automations/runtime/executor.py index da249d8e5..bcdab3940 100644 --- a/surfsense_backend/app/automations/runtime/executor.py +++ b/surfsense_backend/app/automations/runtime/executor.py @@ -132,9 +132,7 @@ def _build_action_ctx( step_id=step.step_id, search_space_id=automation.search_space_id, creator_user_id=automation.created_by_user_id, - agent_llm_id=models.agent_llm_id if models else None, - image_generation_config_id=( - models.image_generation_config_id if models else None - ), - vision_llm_config_id=models.vision_llm_config_id if models else None, + chat_model_id=models.chat_model_id if models else None, + image_gen_model_id=models.image_gen_model_id if models else None, + vision_model_id=models.vision_model_id if models else None, ) diff --git a/surfsense_backend/app/automations/schemas/definition/envelope.py b/surfsense_backend/app/automations/schemas/definition/envelope.py index 7ca55b1ce..787534d4a 100644 --- a/surfsense_backend/app/automations/schemas/definition/envelope.py +++ b/surfsense_backend/app/automations/schemas/definition/envelope.py @@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec class AutomationModels(BaseModel): """Captured model profile for an automation. - Snapshotted from the search space's preferences at create time so runs are - insulated from later chat/search-space model changes. Config-id conventions + Snapshotted from the search space's model roles at create time so runs are + insulated from later chat/search-space model changes. Model-id conventions match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK). """ model_config = ConfigDict(extra="forbid") - agent_llm_id: int = 0 - image_generation_config_id: int = 0 - vision_llm_config_id: int = 0 + chat_model_id: int = 0 + image_gen_model_id: int = 0 + vision_model_id: int = 0 class AutomationDefinition(BaseModel): diff --git a/surfsense_backend/app/automations/services/automation.py b/surfsense_backend/app/automations/services/automation.py index 4227161e2..1d371c35d 100644 --- a/surfsense_backend/app/automations/services/automation.py +++ b/surfsense_backend/app/automations/services/automation.py @@ -57,9 +57,9 @@ class AutomationService: else: search_space = await self._assert_models_billable(payload.search_space_id) payload.definition.models = AutomationModels( - agent_llm_id=search_space.agent_llm_id or 0, - image_generation_config_id=search_space.image_generation_config_id or 0, - vision_llm_config_id=search_space.vision_llm_config_id or 0, + chat_model_id=search_space.chat_model_id or 0, + image_gen_model_id=search_space.image_gen_model_id or 0, + vision_model_id=search_space.vision_model_id or 0, ) automation = Automation( @@ -225,9 +225,9 @@ class AutomationService: """ try: assert_models_billable( - agent_llm_id=models.agent_llm_id, - image_generation_config_id=models.image_generation_config_id, - vision_llm_config_id=models.vision_llm_config_id, + chat_model_id=models.chat_model_id, + image_gen_model_id=models.image_gen_model_id, + vision_model_id=models.vision_model_id, ) except AutomationModelPolicyError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index 7e3e46b61..b160fc78d 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -24,70 +24,45 @@ from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from app.db import SearchSpace -ModelKind = Literal["llm", "image", "vision"] +ModelKind = Literal["chat", "image", "vision"] _KIND_LABEL: dict[ModelKind, str] = { - "llm": "agent LLM", + "chat": "chat model", "image": "image generation model", "vision": "vision model", } -def _is_premium_global(kind: ModelKind, config_id: int) -> bool: - """Return True if a negative (global) config id is a premium tier model.""" +def _is_premium_global(model_id: int) -> bool: + """Return True if a negative (global) model id is a premium tier model.""" from app.config import config as app_config - cfg: dict | None = None - if kind == "llm": - from app.agents.chat.runtime.llm_config import ( - load_global_llm_config_by_id, - ) - - cfg = load_global_llm_config_by_id(config_id) - elif kind == "image": - cfg = next( - ( - c - for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS - if c.get("id") == config_id - ), - None, - ) - else: # vision - cfg = next( - ( - c - for c in app_config.GLOBAL_VISION_LLM_CONFIGS - if c.get("id") == config_id - ), - None, - ) - - if not cfg: + model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None) + if not model: return False - return str(cfg.get("billing_tier", "free")).lower() == "premium" + return str(model.get("billing_tier", "free")).lower() == "premium" -def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: - """Classify a resolved config id as allowed or blocked. +def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]: + """Classify a resolved model id as allowed or blocked. Returns ``(allowed, reason)``; ``reason`` is empty when allowed. """ label = _KIND_LABEL[kind] - if config_id is None or config_id == 0: + if model_id is None or model_id == 0: return ( False, f"The {label} is set to Auto mode. Automations require an explicit " "premium model or your own (BYOK) model so every run is billable.", ) - if config_id > 0: - # Positive id → user-owned BYOK config. Always allowed. + if model_id > 0: + # Positive id -> user/search-space BYOK model. Always allowed. return True, "" - # Negative id → global config. Allowed only if premium. - if _is_premium_global(kind, config_id): + # Negative id -> global model. Allowed only if premium. + if _is_premium_global(model_id): return True, "" return ( @@ -99,27 +74,27 @@ def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: def get_model_eligibility( *, - agent_llm_id: int | None, - image_generation_config_id: int | None, - vision_llm_config_id: int | None, + chat_model_id: int | None, + image_gen_model_id: int | None, + vision_model_id: int | None, ) -> dict: - """Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids. + """Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids. The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is ``{"kind", "config_id", "reason"}``. """ checks: list[tuple[ModelKind, int | None]] = [ - ("llm", agent_llm_id), - ("image", image_generation_config_id), - ("vision", vision_llm_config_id), + ("chat", chat_model_id), + ("image", image_gen_model_id), + ("vision", vision_model_id), ] violations: list[dict] = [] for kind, config_id in checks: allowed, reason = _classify(kind, config_id) if not allowed: - violations.append({"kind": kind, "config_id": config_id, "reason": reason}) + violations.append({"kind": kind, "model_id": config_id, "reason": reason}) return {"allowed": not violations, "violations": violations} @@ -131,9 +106,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict: wrapper over :func:`get_model_eligibility`. """ return get_model_eligibility( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, + chat_model_id=search_space.chat_model_id, + image_gen_model_id=search_space.image_gen_model_id, + vision_model_id=search_space.vision_model_id, ) @@ -150,9 +125,9 @@ class AutomationModelPolicyError(Exception): def assert_models_billable( *, - agent_llm_id: int | None, - image_generation_config_id: int | None, - vision_llm_config_id: int | None, + chat_model_id: int | None, + image_gen_model_id: int | None, + vision_model_id: int | None, ) -> None: """Raise :class:`AutomationModelPolicyError` if any explicit id is not billable. @@ -160,9 +135,9 @@ def assert_models_billable( captured model snapshot. """ result = get_model_eligibility( - agent_llm_id=agent_llm_id, - image_generation_config_id=image_generation_config_id, - vision_llm_config_id=vision_llm_config_id, + chat_model_id=chat_model_id, + image_gen_model_id=image_gen_model_id, + vision_model_id=vision_model_id, ) if not result["allowed"]: raise AutomationModelPolicyError(result["violations"])