diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py index 03cf7acb8..df1ee1b4c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -57,6 +57,7 @@ async def build_agent_with_cache( mcp_tools_by_agent: dict[str, list[BaseTool]], disabled_tools: list[str] | None, config_id: str | None, + image_generation_config_id_override: int | None = None, ) -> Any: """Compile the multi-agent graph, serving from cache when key components are stable.""" @@ -91,7 +92,7 @@ async def build_agent_with_cache( # the key, otherwise a hit will leak state across threads. Bump the schema # version when the component list changes shape. cache_key = stable_hash( - "multi-agent-v1", + "multi-agent-v2", config_id, thread_id, user_id, @@ -109,6 +110,10 @@ async def build_agent_with_cache( system_prompt_hash(final_system_prompt), max_input_tokens, sorted(disabled_tools) if disabled_tools else None, + # Bound into the generate_image subagent tool at construction time, so it + # must key the compiled-agent cache to avoid leaking one automation's + # image model into another with the same config_id/search_space. + image_generation_config_id_override, ) return await get_cache().get_or_build(cache_key, builder=_build) diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py index 8451b3b7d..44529d243 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py @@ -62,8 +62,14 @@ async def create_multi_agent_chat_deep_agent( mentioned_document_ids: list[int] | None = None, anon_session_id: str | None = None, filesystem_selection: FilesystemSelection | None = None, + image_generation_config_id: int | None = None, ): - """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.""" + """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled. + + ``image_generation_config_id`` overrides the search space's image model for + this invocation (used by automations to run on their captured model). When + ``None``, the ``generate_image`` tool resolves the live search-space pref. + """ _t_agent_total = time.perf_counter() apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id) @@ -129,6 +135,9 @@ async def create_multi_agent_chat_deep_agent( "available_document_types": available_document_types, "max_input_tokens": _max_input_tokens, "llm": llm, + # Per-invocation image model override (automations run on their captured + # model). Reaches the generate_image subagent tool via subagent_dependencies. + "image_generation_config_id_override": image_generation_config_id, } _t0 = time.perf_counter() @@ -285,6 +294,7 @@ async def create_multi_agent_chat_deep_agent( mcp_tools_by_agent=mcp_tools_by_agent, disabled_tools=disabled_tools, config_id=config_id, + image_generation_config_id_override=image_generation_config_id, ) _perf_log.info( "[create_agent] Middleware stack + graph compiled in %.3fs", diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py index 173d302e5..8e841c1e9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/tools/automation/create.py @@ -32,7 +32,8 @@ from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated impo ) from app.automations.schemas.api import AutomationCreate from app.automations.services.automation import AutomationService -from app.db import User, async_session_maker +from app.automations.services.model_policy import get_automation_model_eligibility +from app.db import SearchSpace, User, async_session_maker from app.utils.content_utils import extract_text_content from .prompt import build_draft_prompt @@ -98,6 +99,27 @@ def create_create_automation_tool( declined. Acknowledge once and stop — do NOT retry or pitch variants without a fresh user request. """ + # --- 0. Eligibility gate (fail fast, before drafting + HITL) --- + # Automations may only use premium or BYOK models. Check up front so we + # don't make the user draft + approve a card that can't be saved. + async with async_session_maker() as session: + search_space = await session.get(SearchSpace, search_space_id) + if search_space is None: + return { + "status": "error", + "message": "search space not found in this session", + } + eligibility = get_automation_model_eligibility(search_space) + if not eligibility["allowed"]: + reasons = " ".join(v["reason"] for v in eligibility["violations"]) + return { + "status": "error", + "message": ( + f"{reasons} Update the search space's model settings to a " + "premium or your own (BYOK) model, then try again." + ), + } + # --- 1. Draft via sub-LLM --- prompt = build_draft_prompt(search_space_id=search_space_id, intent=intent) try: diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index f170a35db..094371760 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -63,8 +63,14 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: def create_generate_image_tool( search_space_id: int, db_session: AsyncSession, + image_generation_config_id_override: int | None = None, ): - """Create ``generate_image`` with bound search space; DB work uses a per-call session.""" + """Create ``generate_image`` with bound search space; DB work uses a per-call session. + + ``image_generation_config_id_override``: when set (automations running on a + captured model), use this config id instead of reading the search space's + live ``image_generation_config_id``. + """ del db_session # use a fresh per-call session, see below @tool @@ -108,19 +114,27 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return _failed( - {"error": "Search space not found"}, - error="Search space not found", + if image_generation_config_id_override is not None: + # Automation run: use the captured image model, insulated from + # later search-space changes. No search-space read needed. + config_id = ( + image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID ) + else: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return _failed( + {"error": "Search space not found"}, + error="Search space not found", + ) - config_id = ( - search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID - ) + config_id = ( + search_space.image_generation_config_id + or IMAGE_GEN_AUTO_MODE_ID + ) # Build generation kwargs # NOTE: size, quality, and style are intentionally NOT passed. diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index 5f76f1d52..ddfcbd7fb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -51,5 +51,8 @@ def load_tools( create_generate_image_tool( search_space_id=d["search_space_id"], db_session=d["db_session"], + image_generation_config_id_override=d.get( + "image_generation_config_id_override" + ), ), ] diff --git a/surfsense_backend/app/automations/actions/agent_task/dependencies.py b/surfsense_backend/app/automations/actions/agent_task/dependencies.py index 79107cd65..e3736cc95 100644 --- a/surfsense_backend/app/automations/actions/agent_task/dependencies.py +++ b/surfsense_backend/app/automations/actions/agent_task/dependencies.py @@ -8,6 +8,12 @@ from typing import Any from langgraph.checkpoint.memory import InMemorySaver from sqlalchemy.ext.asyncio import AsyncSession +from app.automations.services.model_policy import ( + AutomationModelPolicyError, + assert_automation_models_billable, + assert_models_billable, +) +from app.db import SearchSpace from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle from app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( setup_connector_and_firecrawl, @@ -33,17 +39,48 @@ async def build_dependencies( *, session: AsyncSession, search_space_id: int, + agent_llm_id: int | None = None, + image_generation_config_id: int | None = None, + vision_llm_config_id: int | None = None, ) -> AgentDependencies: """Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer. - Uses the search space's default LLM config (``config_id=-1``). Per-step - model overrides land in a future iteration alongside the ``model`` param. + Resolves the agent LLM from the automation's *captured* model snapshot + (``agent_llm_id``) so runs are insulated from later chat/search-space model + changes. The model policy is enforced here as a runtime backstop: a captured + model that is no longer billable (e.g. a premium global config was removed) + fails the run clearly instead of silently consuming a free model. + + When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback), + fall back to the live search space's ``agent_llm_id`` and validate that. """ + if agent_llm_id is not None: + try: + assert_models_billable( + agent_llm_id=agent_llm_id, + image_generation_config_id=image_generation_config_id, + vision_llm_config_id=vision_llm_config_id, + ) + except AutomationModelPolicyError as exc: + raise DependencyError(str(exc)) from exc + resolved_agent_llm_id = agent_llm_id or 0 + else: + search_space = await session.get(SearchSpace, search_space_id) + if search_space is None: + raise DependencyError(f"search space {search_space_id} not found") + try: + assert_automation_models_billable(search_space) + except AutomationModelPolicyError as exc: + raise DependencyError(str(exc)) from exc + resolved_agent_llm_id = search_space.agent_llm_id or 0 + llm, agent_config, err = await load_llm_bundle( - session, config_id=-1, search_space_id=search_space_id + session, + config_id=resolved_agent_llm_id, + search_space_id=search_space_id, ) if err is not None or llm is None: - raise DependencyError(err or "failed to load default LLM config") + raise DependencyError(err or "failed to load agent LLM config") connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( session, search_space_id=search_space_id diff --git a/surfsense_backend/app/automations/actions/agent_task/invoke.py b/surfsense_backend/app/automations/actions/agent_task/invoke.py index fa02d263f..b645b748d 100644 --- a/surfsense_backend/app/automations/actions/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/agent_task/invoke.py @@ -147,6 +147,9 @@ async def run_agent_task( deps = await build_dependencies( session=agent_session, search_space_id=ctx.search_space_id, + agent_llm_id=ctx.agent_llm_id, + image_generation_config_id=ctx.image_generation_config_id, + vision_llm_config_id=ctx.vision_llm_config_id, ) agent = await create_multi_agent_chat_deep_agent( @@ -161,6 +164,7 @@ async def run_agent_task( firecrawl_api_key=deps.firecrawl_api_key, thread_visibility=ChatVisibility.PRIVATE, mentioned_document_ids=mentioned_document_ids, + image_generation_config_id=ctx.image_generation_config_id, ) agent_query, runtime_context = await _resolve_mention_context( diff --git a/surfsense_backend/app/automations/actions/types.py b/surfsense_backend/app/automations/actions/types.py index 2c4ffad8d..453721a43 100644 --- a/surfsense_backend/app/automations/actions/types.py +++ b/surfsense_backend/app/automations/actions/types.py @@ -20,6 +20,12 @@ class ActionContext: step_id: str search_space_id: int creator_user_id: UUID | None + # Captured model snapshot from the automation definition (``definition.models``), + # resolved per run instead of the live search space. ``None`` falls back to the + # search space's current prefs (defensive; should not happen post-capture). + agent_llm_id: int | None = None + image_generation_config_id: int | None = None + vision_llm_config_id: int | None = None ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] diff --git a/surfsense_backend/app/automations/api/automation.py b/surfsense_backend/app/automations/api/automation.py index b67f0af09..911ae57a6 100644 --- a/surfsense_backend/app/automations/api/automation.py +++ b/surfsense_backend/app/automations/api/automation.py @@ -3,6 +3,7 @@ from __future__ import annotations from fastapi import APIRouter, Depends, Query, status +from pydantic import BaseModel from app.automations.schemas.api import ( AutomationCreate, @@ -16,6 +17,17 @@ from app.automations.services import AutomationService, get_automation_service router = APIRouter() +class ModelEligibilityViolation(BaseModel): + kind: str + config_id: int | None + reason: str + + +class ModelEligibility(BaseModel): + allowed: bool + violations: list[ModelEligibilityViolation] + + @router.post( "/automations", response_model=AutomationDetail, @@ -47,6 +59,23 @@ async def list_automations( ) +@router.get("/automations/model-eligibility", response_model=ModelEligibility) +async def get_automation_model_eligibility( + search_space_id: int = Query(...), + service: AutomationService = Depends(get_automation_service), +) -> ModelEligibility: + """Report whether a search space's models are billable for automations. + + Used by the frontend to gate creation: automations may only use premium + global models or user BYOK models (free models and Auto mode are blocked). + + NOTE: declared before ``/automations/{automation_id}`` so the literal path + isn't captured by the int-typed ``{automation_id}`` route. + """ + result = await service.model_eligibility(search_space_id=search_space_id) + return ModelEligibility.model_validate(result) + + @router.get("/automations/{automation_id}", response_model=AutomationDetail) async def get_automation( automation_id: int, diff --git a/surfsense_backend/app/automations/runtime/executor.py b/surfsense_backend/app/automations/runtime/executor.py index 6a33ab314..da249d8e5 100644 --- a/surfsense_backend/app/automations/runtime/executor.py +++ b/surfsense_backend/app/automations/runtime/executor.py @@ -9,7 +9,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.automations.actions.types import ActionContext from app.automations.persistence.enums.run_status import RunStatus from app.automations.persistence.models.run import AutomationRun -from app.automations.schemas.definition.envelope import AutomationDefinition +from app.automations.schemas.definition.envelope import ( + AutomationDefinition, + AutomationModels, +) from app.automations.schemas.definition.plan_step import PlanStep from app.automations.templating import build_run_context @@ -47,7 +50,7 @@ async def execute_run(session: AsyncSession, run_id: int) -> None: for step in definition.plan: template_ctx = _build_template_ctx(run, step_outputs) - action_ctx = _build_action_ctx(session, run, step) + action_ctx = _build_action_ctx(session, run, step, definition.models) result = await execute_step( step=step, template_context=template_ctx, @@ -82,7 +85,7 @@ async def _run_on_failure( return template_ctx = _build_template_ctx(run, step_outputs={}) for step in definition.execution.on_failure: - action_ctx = _build_action_ctx(session, run, step) + action_ctx = _build_action_ctx(session, run, step, definition.models) result = await execute_step( step=step, template_context=template_ctx, @@ -117,7 +120,10 @@ def _build_template_ctx( def _build_action_ctx( - session: AsyncSession, run: AutomationRun, step: PlanStep + session: AsyncSession, + run: AutomationRun, + step: PlanStep, + models: AutomationModels | None, ) -> ActionContext: automation = run.automation return ActionContext( @@ -126,4 +132,9 @@ def _build_action_ctx( step_id=step.step_id, search_space_id=automation.search_space_id, creator_user_id=automation.created_by_user_id, + agent_llm_id=models.agent_llm_id if models else None, + image_generation_config_id=( + models.image_generation_config_id if models else None + ), + vision_llm_config_id=models.vision_llm_config_id if models else None, ) diff --git a/surfsense_backend/app/automations/schemas/definition/__init__.py b/surfsense_backend/app/automations/schemas/definition/__init__.py index 3fb0a739b..72404264e 100644 --- a/surfsense_backend/app/automations/schemas/definition/__init__.py +++ b/surfsense_backend/app/automations/schemas/definition/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from .envelope import AutomationDefinition +from .envelope import AutomationDefinition, AutomationModels from .execution import Execution from .inputs import Inputs from .metadata import Metadata @@ -11,6 +11,7 @@ from .trigger_spec import TriggerSpec __all__ = [ "AutomationDefinition", + "AutomationModels", "Execution", "Inputs", "Metadata", diff --git a/surfsense_backend/app/automations/schemas/definition/envelope.py b/surfsense_backend/app/automations/schemas/definition/envelope.py index f919b2abb..7ca55b1ce 100644 --- a/surfsense_backend/app/automations/schemas/definition/envelope.py +++ b/surfsense_backend/app/automations/schemas/definition/envelope.py @@ -11,6 +11,21 @@ from .plan_step import PlanStep from .trigger_spec import TriggerSpec +class AutomationModels(BaseModel): + """Captured model profile for an automation. + + Snapshotted from the search space's preferences at create time so runs are + insulated from later chat/search-space model changes. Config-id conventions + match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK). + """ + + model_config = ConfigDict(extra="forbid") + + agent_llm_id: int = 0 + image_generation_config_id: int = 0 + vision_llm_config_id: int = 0 + + class AutomationDefinition(BaseModel): """Top-level shape of an automation.""" @@ -24,3 +39,7 @@ class AutomationDefinition(BaseModel): plan: list[PlanStep] = Field(..., min_length=1) execution: Execution = Field(default_factory=Execution) metadata: Metadata = Field(default_factory=Metadata) + # Captured server-side at create() and preserved across update(); resolved + # at runtime instead of the live search space. Optional so drafts/builder + # payloads validate without it. + models: AutomationModels | None = None diff --git a/surfsense_backend/app/automations/services/__init__.py b/surfsense_backend/app/automations/services/__init__.py index 597aca98a..904a3413a 100644 --- a/surfsense_backend/app/automations/services/__init__.py +++ b/surfsense_backend/app/automations/services/__init__.py @@ -3,14 +3,26 @@ from __future__ import annotations from .automation import AutomationService, get_automation_service +from .model_policy import ( + AutomationModelPolicyError, + assert_automation_models_billable, + assert_models_billable, + get_automation_model_eligibility, + get_model_eligibility, +) from .run import RunService, get_run_service from .trigger import TriggerService, get_trigger_service __all__ = [ + "AutomationModelPolicyError", "AutomationService", "RunService", "TriggerService", + "assert_automation_models_billable", + "assert_models_billable", + "get_automation_model_eligibility", "get_automation_service", + "get_model_eligibility", "get_run_service", "get_trigger_service", ] diff --git a/surfsense_backend/app/automations/services/automation.py b/surfsense_backend/app/automations/services/automation.py index 0d2937e0e..6c602d886 100644 --- a/surfsense_backend/app/automations/services/automation.py +++ b/surfsense_backend/app/automations/services/automation.py @@ -18,9 +18,15 @@ from app.automations.schemas.api import ( AutomationUpdate, TriggerCreate, ) +from app.automations.schemas.definition.envelope import AutomationModels +from app.automations.services.model_policy import ( + AutomationModelPolicyError, + assert_automation_models_billable, + get_automation_model_eligibility, +) from app.automations.triggers import get_trigger from app.automations.triggers.schedule import compute_next_fire_at -from app.db import Permission, User, get_async_session +from app.db import Permission, SearchSpace, User, get_async_session from app.users import current_active_user from app.utils.rbac import check_permission @@ -37,6 +43,16 @@ class AutomationService: await self._authorize( payload.search_space_id, Permission.AUTOMATIONS_CREATE.value ) + search_space = await self._assert_models_billable(payload.search_space_id) + + # Snapshot the search space's current (already-validated) model prefs onto + # the definition so runs are insulated from later chat/search-space model + # changes. Captured ids are guaranteed billable by the check above. + payload.definition.models = AutomationModels( + agent_llm_id=search_space.agent_llm_id or 0, + image_generation_config_id=search_space.image_generation_config_id or 0, + vision_llm_config_id=search_space.vision_llm_config_id or 0, + ) automation = Automation( search_space_id=payload.search_space_id, @@ -105,9 +121,15 @@ class AutomationService: if "status" in data: automation.status = data["status"] if "definition" in data: - automation.definition = patch.definition.model_dump( - mode="json", by_alias=True - ) + new_def = patch.definition.model_dump(mode="json", by_alias=True) + # Preserve the captured model snapshot across edits so a definition + # change never silently re-binds the automation to the current chat + # model selection. Backend-managed; survives whether or not the + # client round-trips ``models``. + existing_models = (automation.definition or {}).get("models") + if existing_models is not None: + new_def["models"] = existing_models + automation.definition = new_def automation.version += 1 await self.session.commit() @@ -143,6 +165,40 @@ class AutomationService: ) return automation + async def model_eligibility(self, *, search_space_id: int) -> dict: + """Return whether a search space's models are billable for automations. + + ``{"allowed": bool, "violations": [{kind, config_id, reason}, ...]}``. + """ + await self._authorize(search_space_id, Permission.AUTOMATIONS_READ.value) + search_space = await self.session.get(SearchSpace, search_space_id) + if search_space is None: + raise HTTPException( + status_code=404, detail=f"search space {search_space_id} not found" + ) + return get_automation_model_eligibility(search_space) + + async def _assert_models_billable(self, search_space_id: int) -> SearchSpace: + """Reject creation when the search space's models aren't billable. + + Automations may only use premium global models or user BYOK models; free + global models and Auto mode are blocked. Mirrors the runtime backstop in + ``agent_task`` so users can't save an automation that would fail to run. + + Returns the loaded :class:`SearchSpace` so the caller can capture its + model prefs without a second DB read. + """ + search_space = await self.session.get(SearchSpace, search_space_id) + if search_space is None: + raise HTTPException( + status_code=404, detail=f"search space {search_space_id} not found" + ) + try: + assert_automation_models_billable(search_space) + except AutomationModelPolicyError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + return search_space + async def _authorize(self, search_space_id: int, permission: str) -> None: await check_permission( self.session, diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py new file mode 100644 index 000000000..88e9d5f28 --- /dev/null +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -0,0 +1,173 @@ +"""Model-billing policy for automations. + +Automations run unattended, so every run must be **billable**: it may only use +either a premium global model (``billing_tier == "premium"``) or a user-provided +BYOK model (a positive config id pointing at a per-user/per-space DB row). Free +global models and Auto mode are blocked, because Auto can dispatch to a free +deployment and free models aren't metered in premium credits. + +Config id conventions (shared across chat / image / vision): +- ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` / + ``VISION_AUTO_MODE_ID``). Blocked. +- ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium. +- ``id > 0`` → user BYOK DB row. Always allowed. + +This module is the single source of truth used by both creation-time enforcement +(``AutomationService.create`` and the ``create_automation`` chat tool) and the +runtime backstop (``agent_task`` dependencies). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from app.db import SearchSpace + +ModelKind = Literal["llm", "image", "vision"] + +_KIND_LABEL: dict[ModelKind, str] = { + "llm": "agent LLM", + "image": "image generation model", + "vision": "vision model", +} + + +def _is_premium_global(kind: ModelKind, config_id: int) -> bool: + """Return True if a negative (global) config id is a premium tier model.""" + from app.config import config as app_config + + cfg: dict | None = None + if kind == "llm": + from app.agents.new_chat.llm_config import load_global_llm_config_by_id + + cfg = load_global_llm_config_by_id(config_id) + elif kind == "image": + cfg = next( + ( + c + for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS + if c.get("id") == config_id + ), + None, + ) + else: # vision + cfg = next( + ( + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if c.get("id") == config_id + ), + None, + ) + + if not cfg: + return False + return str(cfg.get("billing_tier", "free")).lower() == "premium" + + +def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: + """Classify a resolved config id as allowed or blocked. + + Returns ``(allowed, reason)``; ``reason`` is empty when allowed. + """ + label = _KIND_LABEL[kind] + + if config_id is None or config_id == 0: + return ( + False, + f"The {label} is set to Auto mode. Automations require an explicit " + "premium model or your own (BYOK) model so every run is billable.", + ) + + if config_id > 0: + # Positive id → user-owned BYOK config. Always allowed. + return True, "" + + # Negative id → global config. Allowed only if premium. + if _is_premium_global(kind, config_id): + return True, "" + + return ( + False, + f"The {label} is a free model. Automations can only use premium models " + "or your own (BYOK) models so every run is billable.", + ) + + +def get_model_eligibility( + *, + agent_llm_id: int | None, + image_generation_config_id: int | None, + vision_llm_config_id: int | None, +) -> dict: + """Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids. + + The ID-based core shared by both the search-space path (creation/eligibility) + and the captured-snapshot path (runtime backstop). Each violation is + ``{"kind", "config_id", "reason"}``. + """ + checks: list[tuple[ModelKind, int | None]] = [ + ("llm", agent_llm_id), + ("image", image_generation_config_id), + ("vision", vision_llm_config_id), + ] + + violations: list[dict] = [] + for kind, config_id in checks: + allowed, reason = _classify(kind, config_id) + if not allowed: + violations.append({"kind": kind, "config_id": config_id, "reason": reason}) + + return {"allowed": not violations, "violations": violations} + + +def get_automation_model_eligibility(search_space: SearchSpace) -> dict: + """Return ``{"allowed": bool, "violations": [...]}`` for a search space. + + Used by the eligibility endpoint and the chat tool's early check. Thin + wrapper over :func:`get_model_eligibility`. + """ + return get_model_eligibility( + agent_llm_id=search_space.agent_llm_id, + image_generation_config_id=search_space.image_generation_config_id, + vision_llm_config_id=search_space.vision_llm_config_id, + ) + + +class AutomationModelPolicyError(Exception): + """Raised when a search space's models are not billable for automations.""" + + def __init__(self, violations: list[dict]) -> None: + self.violations = violations + reasons = "; ".join(v["reason"] for v in violations) + super().__init__( + reasons or "Automations require premium or BYOK models for all model slots." + ) + + +def assert_models_billable( + *, + agent_llm_id: int | None, + image_generation_config_id: int | None, + vision_llm_config_id: int | None, +) -> None: + """Raise :class:`AutomationModelPolicyError` if any explicit id is not billable. + + The ID-based core used by the runtime backstop against an automation's + captured model snapshot. + """ + result = get_model_eligibility( + agent_llm_id=agent_llm_id, + image_generation_config_id=image_generation_config_id, + vision_llm_config_id=vision_llm_config_id, + ) + if not result["allowed"]: + raise AutomationModelPolicyError(result["violations"]) + + +def assert_automation_models_billable(search_space: SearchSpace) -> None: + """Raise :class:`AutomationModelPolicyError` if any model slot is not billable.""" + result = get_automation_model_eligibility(search_space) + if not result["allowed"]: + raise AutomationModelPolicyError(result["violations"]) diff --git a/surfsense_backend/tests/unit/automations/actions/agent_task/test_dependencies.py b/surfsense_backend/tests/unit/automations/actions/agent_task/test_dependencies.py new file mode 100644 index 000000000..ac20b2608 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/actions/agent_task/test_dependencies.py @@ -0,0 +1,174 @@ +"""Lock the runtime model-policy backstop in ``build_dependencies``. + +Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so +runs are insulated from later chat/search-space model changes), and the model +policy is re-checked at run time so a captured model that is no longer billable +fails the run clearly. When no snapshot is present, resolution falls back to the +live search space. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +import app.automations.actions.agent_task.dependencies as deps_mod +from app.automations.actions.agent_task.dependencies import ( + DependencyError, + build_dependencies, +) +from app.automations.services.model_policy import AutomationModelPolicyError + +pytestmark = pytest.mark.unit + + +class _FakeSession: + """Minimal async session whose ``get`` returns a preset search space.""" + + def __init__(self, search_space: Any) -> None: + self._search_space = search_space + + async def get(self, _model: Any, _pk: int) -> Any: + return self._search_space + + +@pytest.fixture +def patched_side_effects(monkeypatch: pytest.MonkeyPatch): + """Stub the connector setup + checkpointer so only policy/LLM logic runs.""" + + async def _fake_setup(_session, *, search_space_id): + return (SimpleNamespace(name="connector"), "fc-key") + + monkeypatch.setattr(deps_mod, "setup_connector_and_firecrawl", _fake_setup) + return None + + +async def test_build_dependencies_resolves_captured_agent_llm_id( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """The bundle loads with the *captured* ``agent_llm_id``, not the live search space.""" + captured: dict[str, Any] = {} + + async def _fake_load(_session, *, config_id, search_space_id): + captured["config_id"] = config_id + captured["search_space_id"] = search_space_id + return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None) + + monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) + # Captured path validates the explicit ids; passes for this test. + monkeypatch.setattr(deps_mod, "assert_models_billable", lambda **_kw: None) + # A different value on the live search space proves we ignore it when a + # snapshot is supplied. + monkeypatch.setattr( + deps_mod, + "assert_automation_models_billable", + lambda _ss: pytest.fail("search-space policy should not run on captured path"), + ) + + search_space = SimpleNamespace(agent_llm_id=-99) + result = await build_dependencies( + session=_FakeSession(search_space), + search_space_id=42, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + + assert captured == {"config_id": -7, "search_space_id": 42} + assert result.llm.name == "llm" + assert result.firecrawl_api_key == "fc-key" + + +async def test_build_dependencies_validates_captured_ids( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """The captured ids (not the search space) are what gets policy-checked.""" + seen: dict[str, Any] = {} + + def _capture(**kwargs): + seen.update(kwargs) + + monkeypatch.setattr(deps_mod, "assert_models_billable", _capture) + + async def _fake_load(_session, *, config_id, search_space_id): + return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None) + + monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) + + await build_dependencies( + session=_FakeSession(SimpleNamespace(agent_llm_id=0)), + search_space_id=42, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + + assert seen == { + "agent_llm_id": -7, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + + +async def test_build_dependencies_raises_on_captured_policy_violation( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """A blocked captured model raises ``DependencyError`` so the step fails clearly.""" + + def _raise(**_kw): + raise AutomationModelPolicyError( + [{"kind": "image", "config_id": -2, "reason": "free model"}] + ) + + monkeypatch.setattr(deps_mod, "assert_models_billable", _raise) + monkeypatch.setattr( + deps_mod, + "load_llm_bundle", + lambda *a, **k: pytest.fail("load_llm_bundle should not be called"), + ) + + with pytest.raises(DependencyError): + await build_dependencies( + session=_FakeSession(SimpleNamespace(agent_llm_id=-7)), + search_space_id=42, + agent_llm_id=-7, + image_generation_config_id=-2, + vision_llm_config_id=-1, + ) + + +async def test_build_dependencies_falls_back_to_search_space( + monkeypatch: pytest.MonkeyPatch, patched_side_effects +) -> None: + """With no captured snapshot, resolve + validate the live search space.""" + captured: dict[str, Any] = {} + + async def _fake_load(_session, *, config_id, search_space_id): + captured["config_id"] = config_id + return (SimpleNamespace(name="llm"), SimpleNamespace(name="agent_config"), None) + + monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) + monkeypatch.setattr(deps_mod, "assert_automation_models_billable", lambda _ss: None) + monkeypatch.setattr( + deps_mod, + "assert_models_billable", + lambda **_kw: pytest.fail("captured policy should not run on fallback path"), + ) + + search_space = SimpleNamespace(agent_llm_id=-7) + result = await build_dependencies( + session=_FakeSession(search_space), search_space_id=42 + ) + + assert captured == {"config_id": -7} + assert result.llm.name == "llm" + + +async def test_build_dependencies_raises_when_search_space_missing( + patched_side_effects, +) -> None: + """A missing search space (fallback path) surfaces as a ``DependencyError``.""" + with pytest.raises(DependencyError): + await build_dependencies(session=_FakeSession(None), search_space_id=999) diff --git a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py new file mode 100644 index 000000000..d7e3c4a0c --- /dev/null +++ b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py @@ -0,0 +1,59 @@ +"""Lock that the executor propagates the captured model snapshot into the +``ActionContext``, so runs resolve their own model (insulated from chat / +search-space changes) and not the live search space. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import cast + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.automations.runtime.executor import _build_action_ctx +from app.automations.schemas.definition.envelope import AutomationModels +from app.automations.schemas.definition.plan_step import PlanStep + +pytestmark = pytest.mark.unit + + +def _run() -> SimpleNamespace: + return SimpleNamespace( + id=1, + automation=SimpleNamespace(search_space_id=42, created_by_user_id="u-1"), + ) + + +def test_build_action_ctx_propagates_captured_models() -> None: + """``definition.models`` flows onto the ActionContext model fields.""" + models = AutomationModels( + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + ctx = _build_action_ctx( + cast(AsyncSession, None), + _run(), + PlanStep(step_id="s1", action="agent_task"), + models, + ) + + assert ctx.search_space_id == 42 + assert ctx.agent_llm_id == -1 + assert ctx.image_generation_config_id == 5 + assert ctx.vision_llm_config_id == -1 + + +def test_build_action_ctx_none_models_leaves_fields_none() -> None: + """No captured snapshot → model fields are None (defensive fallback path).""" + ctx = _build_action_ctx( + cast(AsyncSession, None), + _run(), + PlanStep(step_id="s1", action="agent_task"), + None, + ) + + assert ctx.agent_llm_id is None + assert ctx.image_generation_config_id is None + assert ctx.vision_llm_config_id is None diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py index d7b392a1d..25e193ffa 100644 --- a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py @@ -5,7 +5,10 @@ from __future__ import annotations import pytest from pydantic import ValidationError -from app.automations.schemas.definition.envelope import AutomationDefinition +from app.automations.schemas.definition.envelope import ( + AutomationDefinition, + AutomationModels, +) from app.automations.schemas.definition.plan_step import PlanStep pytestmark = pytest.mark.unit @@ -27,6 +30,34 @@ def test_automation_definition_accepts_minimal_valid_input_with_sensible_default assert definition.goal is None assert definition.inputs is None assert definition.triggers == [] + # ``models`` is optional (populated server-side at create()). + assert definition.models is None + + +def test_automation_definition_models_round_trip() -> None: + """The captured ``models`` snapshot survives a model_dump/validate round-trip.""" + definition = AutomationDefinition( + name="Daily digest", + plan=[PlanStep(step_id="s1", action="agent_task")], + models=AutomationModels( + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, + ), + ) + + dumped = definition.model_dump(mode="json", by_alias=True) + assert dumped["models"] == { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + + restored = AutomationDefinition.model_validate(dumped) + assert restored.models is not None + assert restored.models.agent_llm_id == -1 + assert restored.models.image_generation_config_id == 5 + assert restored.models.vision_llm_config_id == -1 def test_automation_definition_rejects_unknown_top_level_field() -> None: diff --git a/surfsense_backend/tests/unit/automations/services/__init__.py b/surfsense_backend/tests/unit/automations/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py new file mode 100644 index 000000000..d81302380 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -0,0 +1,236 @@ +"""Lock creation-time model-policy enforcement in ``AutomationService``. + +Creation (REST + manual builder) rejects search spaces whose models aren't +billable for automations with HTTP 422, mirroring the runtime backstop. These +tests isolate the new ``_assert_models_billable`` / ``model_eligibility`` paths +without touching the DB commit. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from fastapi import HTTPException + +import app.automations.services.automation as automation_mod +from app.automations.schemas.api import AutomationCreate, AutomationUpdate +from app.automations.schemas.definition.envelope import AutomationDefinition +from app.automations.schemas.definition.plan_step import PlanStep +from app.automations.services.automation import AutomationService +from app.automations.services.model_policy import AutomationModelPolicyError + +pytestmark = pytest.mark.unit + + +class _FakeSession: + def __init__(self, search_space: Any) -> None: + self._search_space = search_space + self.added: list[Any] = [] + self.commits = 0 + + async def get(self, _model: Any, _pk: int) -> Any: + return self._search_space + + def add(self, obj: Any) -> None: + self.added.append(obj) + + async def commit(self) -> None: + self.commits += 1 + + +def _service(search_space: Any) -> AutomationService: + return AutomationService( + session=_FakeSession(search_space), user=SimpleNamespace(id="u-1") + ) + + +def _definition(**kwargs: Any) -> AutomationDefinition: + return AutomationDefinition( + name="A", + plan=[PlanStep(step_id="s1", action="agent_task")], + **kwargs, + ) + + +async def test_assert_models_billable_raises_422_on_violation( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A blocked model maps the policy error to HTTP 422.""" + + def _raise(_ss): + raise AutomationModelPolicyError( + [{"kind": "llm", "config_id": 0, "reason": "Auto mode"}] + ) + + monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise) + + service = _service(SimpleNamespace(agent_llm_id=0)) + with pytest.raises(HTTPException) as exc_info: + await service._assert_models_billable(1) + + assert exc_info.value.status_code == 422 + + +async def test_assert_models_billable_raises_404_when_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A missing search space is a 404, not a policy error.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + service = _service(None) + with pytest.raises(HTTPException) as exc_info: + await service._assert_models_billable(999) + + assert exc_info.value.status_code == 404 + + +async def test_assert_models_billable_returns_search_space_when_ok( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When the policy accepts, the loaded search space is returned for reuse.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + search_space = SimpleNamespace(agent_llm_id=-1) + service = _service(search_space) + assert await service._assert_models_billable(1) is search_space + + +async def test_create_injects_captured_models_from_search_space( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """create() snapshots the search space's model prefs onto the definition.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + async def _noop_authorize(self, *_a, **_k): + return None + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + + async def _return_added(self, _aid): + return self.session.added[-1] + + monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) + + search_space = SimpleNamespace( + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, + ) + service = _service(search_space) + payload = AutomationCreate( + search_space_id=1, + name="A", + definition=_definition(), + ) + + automation = await service.create(payload) + + assert automation.definition["models"] == { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + + +async def test_create_treats_unset_prefs_as_auto_zero( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``None`` search-space prefs are captured as ``0`` (Auto) ids.""" + monkeypatch.setattr( + automation_mod, "assert_automation_models_billable", lambda _ss: None + ) + + async def _noop_authorize(self, *_a, **_k): + return None + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + + async def _return_added(self, _aid): + return self.session.added[-1] + + monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) + + search_space = SimpleNamespace( + agent_llm_id=None, + image_generation_config_id=None, + vision_llm_config_id=None, + ) + service = _service(search_space) + payload = AutomationCreate(search_space_id=1, name="A", definition=_definition()) + + automation = await service.create(payload) + + assert automation.definition["models"] == { + "agent_llm_id": 0, + "image_generation_config_id": 0, + "vision_llm_config_id": 0, + } + + +async def test_update_preserves_captured_models( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A definition edit carries over the previously captured ``models``.""" + captured = { + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, + } + existing = SimpleNamespace( + search_space_id=1, + definition={"name": "A", "plan": [], "models": captured}, + version=3, + ) + + async def _noop_authorize(self, *_a, **_k): + return None + + async def _return_existing(self, _aid): + return existing + + monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) + monkeypatch.setattr( + AutomationService, "_get_with_triggers_or_raise", _return_existing + ) + + service = _service(SimpleNamespace()) + # The incoming patch definition has no ``models`` (frontend strips it). + patch = AutomationUpdate(definition=_definition()) + + result = await service.update(7, patch) + + assert result.definition["models"] == captured + assert result.version == 4 + + +async def test_model_eligibility_authorizes_and_returns_payload( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``model_eligibility`` checks read access then returns the eligibility dict.""" + authorized: dict[str, Any] = {} + + async def _fake_check_permission(_session, _user, ss_id, permission, _msg): + authorized["search_space_id"] = ss_id + authorized["permission"] = permission + + monkeypatch.setattr(automation_mod, "check_permission", _fake_check_permission) + monkeypatch.setattr( + automation_mod, + "get_automation_model_eligibility", + lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]}, + ) + + service = _service(SimpleNamespace(agent_llm_id=-2)) + result = await service.model_eligibility(search_space_id=5) + + assert result == {"allowed": False, "violations": [{"kind": "image"}]} + assert authorized["search_space_id"] == 5 + assert authorized["permission"] == "automations:read" diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py b/surfsense_backend/tests/unit/automations/services/test_model_policy.py new file mode 100644 index 000000000..2a471b4e9 --- /dev/null +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -0,0 +1,196 @@ +"""Lock the automation model-billing policy. + +Automations may only run on billable models: premium global configs +(``billing_tier == "premium"``) or user BYOK configs (positive id). Free +globals and Auto mode (id == 0 / None) are blocked. These tests pin that rule +across all three model slots (chat LLM, image, vision). +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import app.automations.services.model_policy as model_policy +from app.automations.services.model_policy import ( + AutomationModelPolicyError, + assert_automation_models_billable, + assert_models_billable, + get_automation_model_eligibility, + get_model_eligibility, +) + +pytestmark = pytest.mark.unit + + +def _search_space(*, llm: int | None, image: int | None, vision: int | None): + """Minimal stand-in for the ``SearchSpace`` ORM row the policy reads.""" + return SimpleNamespace( + agent_llm_id=llm, + image_generation_config_id=image, + vision_llm_config_id=vision, + ) + + +@pytest.fixture +def patched_globals(monkeypatch: pytest.MonkeyPatch): + """Stub the global config sources the policy consults for negative ids. + + Negative ids: -1 is premium, -2 is free, for each of llm/image/vision. + """ + llm_configs = { + -1: {"id": -1, "billing_tier": "premium"}, + -2: {"id": -2, "billing_tier": "free"}, + } + monkeypatch.setattr( + "app.agents.new_chat.llm_config.load_global_llm_config_by_id", + lambda cid: llm_configs.get(cid), + ) + + from app.config import config as app_config + + monkeypatch.setattr( + app_config, + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + {"id": -1, "billing_tier": "premium"}, + {"id": -2, "billing_tier": "free"}, + ], + raising=False, + ) + monkeypatch.setattr( + app_config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + {"id": -1, "billing_tier": "premium"}, + {"id": -2, "billing_tier": "free"}, + ], + raising=False, + ) + return None + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: + """A positive config id is a user-owned BYOK model — always billable.""" + allowed, reason = model_policy._classify(kind, 7) + assert allowed is True + assert reason == "" + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("config_id", [0, None]) +def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: + """Auto mode (id 0) and an unset slot (None) are blocked.""" + allowed, reason = model_policy._classify(kind, config_id) + assert allowed is False + assert "Auto mode" in reason + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_premium_global_is_allowed(kind: str, patched_globals) -> None: + """A negative (global) id with premium billing tier is allowed.""" + allowed, reason = model_policy._classify(kind, -1) + assert allowed is True + assert reason == "" + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_free_global_is_blocked(kind: str, patched_globals) -> None: + """A negative (global) id with a free billing tier is blocked.""" + allowed, reason = model_policy._classify(kind, -2) + assert allowed is False + assert "free model" in reason + + +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None: + """A negative id that resolves to no config is treated as not premium.""" + allowed, _ = model_policy._classify(kind, -999) + assert allowed is False + + +def test_eligibility_all_billable(patched_globals) -> None: + """Premium LLM + BYOK image + premium vision → allowed, no violations.""" + search_space = _search_space(llm=-1, image=5, vision=-1) + result = get_automation_model_eligibility(search_space) + assert result == {"allowed": True, "violations": []} + + +def test_eligibility_reports_each_violation(patched_globals) -> None: + """A free LLM, Auto image, and free vision each produce a violation.""" + search_space = _search_space(llm=-2, image=0, vision=-2) + result = get_automation_model_eligibility(search_space) + + assert result["allowed"] is False + kinds = {v["kind"] for v in result["violations"]} + assert kinds == {"llm", "image", "vision"} + # config_id is echoed back for the UI / settings deep-link. + by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} + + +def test_assert_raises_with_violations(patched_globals) -> None: + """``assert_automation_models_billable`` raises when any slot is blocked.""" + search_space = _search_space(llm=0, image=5, vision=-1) + with pytest.raises(AutomationModelPolicyError) as exc_info: + assert_automation_models_billable(search_space) + + assert len(exc_info.value.violations) == 1 + assert exc_info.value.violations[0]["kind"] == "llm" + + +def test_assert_passes_when_all_billable(patched_globals) -> None: + """No exception when every slot is premium or BYOK.""" + search_space = _search_space(llm=3, image=-1, vision=4) + assert assert_automation_models_billable(search_space) is None + + +# --- ID-based core (used by the runtime backstop against captured snapshots) --- + + +def test_get_model_eligibility_all_billable(patched_globals) -> None: + """Premium LLM + BYOK image + premium vision (explicit ids) → allowed.""" + result = get_model_eligibility( + agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1 + ) + assert result == {"allowed": True, "violations": []} + + +def test_get_model_eligibility_reports_each_violation(patched_globals) -> None: + """Free LLM, Auto image, free vision (explicit ids) each produce a violation.""" + result = get_model_eligibility( + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + ) + assert result["allowed"] is False + by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} + + +def test_assert_models_billable_raises(patched_globals) -> None: + """``assert_models_billable`` raises when any explicit id is blocked.""" + with pytest.raises(AutomationModelPolicyError) as exc_info: + assert_models_billable( + agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1 + ) + assert len(exc_info.value.violations) == 1 + assert exc_info.value.violations[0]["kind"] == "llm" + + +def test_assert_models_billable_passes(patched_globals) -> None: + """No exception when every explicit id is premium or BYOK.""" + assert ( + assert_models_billable( + agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4 + ) + is None + ) + + +def test_search_space_wrapper_delegates_to_core(patched_globals) -> None: + """The search-space wrapper produces the same result as the ID core.""" + search_space = _search_space(llm=-2, image=0, vision=-2) + assert get_automation_model_eligibility(search_space) == get_model_eligibility( + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + ) diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/automations-content.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/automations-content.tsx index 756221d38..6bbe55ec9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/automations-content.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/automations-content.tsx @@ -1,6 +1,8 @@ "use client"; import { ShieldAlert } from "lucide-react"; +import { useAutomationModelEligibility } from "@/hooks/use-automation-model-eligibility"; import { useAutomations } from "@/hooks/use-automations"; +import { AutomationModelGateAlert } from "./components/automation-model-gate-alert"; import { AutomationsEmptyState } from "./components/automations-empty-state"; import { AutomationsHeader } from "./components/automations-header"; import { AutomationsTable } from "./components/automations-table"; @@ -22,6 +24,18 @@ interface AutomationsContentProps { export function AutomationsContent({ searchSpaceId }: AutomationsContentProps) { const { automations, total, loading, error } = useAutomations(); const perms = useAutomationPermissions(); + // Gate creation on billable models (premium/BYOK). Only meaningful for + // users who can create; the eligibility query loads in parallel. + const { data: eligibility, isLoading: eligibilityLoading } = useAutomationModelEligibility( + perms.canCreate ? searchSpaceId : undefined + ); + const modelViolations = eligibility?.violations ?? []; + // Disable create CTAs while loading (avoid a flash of enabled buttons) and + // when the resolved models aren't billable. + const createDisabled = perms.canCreate && (eligibilityLoading || modelViolations.length > 0); + const disabledReason = eligibilityLoading + ? "Checking model eligibility…" + : modelViolations[0]?.reason; if (perms.loading) { // Permissions gate the entire page; defer everything until we know. @@ -77,7 +91,11 @@ export function AutomationsContent({ searchSpaceId }: AutomationsContentProps) { canCreate={perms.canCreate} showCreateCta={false} /> - + ); } @@ -89,7 +107,12 @@ export function AutomationsContent({ searchSpaceId }: AutomationsContentProps) { total={total} loading={loading} canCreate={perms.canCreate} + createDisabled={createDisabled} + disabledReason={disabledReason} /> + {modelViolations.length > 0 && ( + + )} = { + llm: "Agent model", + image: "Image model", + vision: "Vision model", +}; + +/** + * Warns that the search space's models aren't billable for automations. + * + * Automations may only use premium global models or your own (BYOK) models — + * free models and Auto mode are blocked so every run is metered in premium + * credits. Surfaced wherever a user can start creating an automation. + */ +export function AutomationModelGateAlert({ + searchSpaceId, + violations, + className, +}: AutomationModelGateAlertProps) { + if (violations.length === 0) return null; + + return ( + + + Automations need a premium or your own model + +

+ Automations run unattended, so every run must use a premium model or your own (BYOK) + model. Update these in your model settings, then create your automation. +

+
    + {violations.map((violation) => ( +
  • + + {KIND_LABEL[violation.kind]} + + — {violation.reason} +
  • + ))} +
+
+
+ ); +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx index cc54c5e94..ee7dadce6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx @@ -2,10 +2,14 @@ import { MessageSquarePlus, SquarePen, Workflow } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; +import type { ModelEligibilityViolation } from "@/contracts/types/automation.types"; +import { AutomationModelGateAlert } from "./automation-model-gate-alert"; interface AutomationsEmptyStateProps { searchSpaceId: number; canCreate: boolean; + /** Model slots that block creation (free/Auto). Empty when eligible. */ + modelViolations?: ModelEligibilityViolation[]; } /** @@ -14,7 +18,13 @@ interface AutomationsEmptyStateProps { * "new automation" form. We surface the chat path explicitly so users * don't go hunting for an "add" button that doesn't exist. */ -export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsEmptyStateProps) { +export function AutomationsEmptyState({ + searchSpaceId, + canCreate, + modelViolations = [], +}: AutomationsEmptyStateProps) { + const modelsBlocked = modelViolations.length > 0; + return (
@@ -26,20 +36,26 @@ export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsE SurfSense drafts the automation for your approval.

{canCreate ? ( -
- - -
+ modelsBlocked ? ( +
+ +
+ ) : ( +
+ + +
+ ) ) : (

You don't have permission to create automations in this search space. diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-header.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-header.tsx index 8d5fab033..308eaccfb 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-header.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-header.tsx @@ -1,7 +1,9 @@ "use client"; import { MessageSquarePlus, SquarePen } from "lucide-react"; import Link from "next/link"; +import type { ReactNode } from "react"; import { Button } from "@/components/ui/button"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; interface AutomationsHeaderProps { searchSpaceId: number; @@ -14,8 +16,18 @@ interface AutomationsHeaderProps { * there to avoid a duplicate button. */ showCreateCta?: boolean; + /** + * Disable the create CTAs when the search space's models aren't billable + * for automations (free/Auto). When set, a tooltip explains why and the + * buttons render disabled rather than as links. + */ + createDisabled?: boolean; + disabledReason?: string; } +const DEFAULT_DISABLED_REASON = + "Automations need a premium or your own (BYOK) model. Update your model settings to enable."; + /** * Page header: title + count + "Create via chat" CTA. Creation is intent-driven * (the create_automation tool runs inside chat with a HITL approval card), so @@ -27,6 +39,8 @@ export function AutomationsHeader({ loading, canCreate, showCreateCta = true, + createDisabled = false, + disabledReason, }: AutomationsHeaderProps) { return (

@@ -40,20 +54,70 @@ export function AutomationsHeader({
{canCreate && showCreateCta && (
- - + {createDisabled ? ( + <> + } + label="Create manually" + reason={disabledReason ?? DEFAULT_DISABLED_REASON} + /> + } + label="Create via chat" + reason={disabledReason ?? DEFAULT_DISABLED_REASON} + /> + + ) : ( + <> + + + + )}
)}
); } + +function DisabledCta({ + icon, + label, + reason, + variant, +}: { + icon: ReactNode; + label: string; + reason: string; + variant?: "outline"; +}) { + return ( + + + {/* aria-disabled (not `disabled`) keeps the button focusable so the + tooltip is reachable by hover and keyboard; onClick is a no-op. */} + + + {reason} + + ); +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx index 1fd37cd3d..f3323da61 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx @@ -15,6 +15,7 @@ import { import { Button } from "@/components/ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { type Automation, automationCreateRequest, @@ -45,6 +46,12 @@ interface AutomationBuilderFormProps { searchSpaceId: number; /** Required in edit mode; seeds the form and trigger reconciliation. */ automation?: Automation; + /** + * When set (create mode only), the search space's models aren't billable + * for automations. Submit is disabled with this reason as the tooltip; the + * orchestrator also renders the full gate alert above the form. + */ + submitDisabledReason?: string; } type Mode = "form" | "json"; @@ -66,6 +73,7 @@ export function AutomationBuilderForm({ mode, searchSpaceId, automation, + submitDisabledReason, }: AutomationBuilderFormProps) { const router = useRouter(); const { mutateAsync: createAutomation } = useAtomValue(createAutomationMutationAtom); @@ -273,6 +281,8 @@ export function AutomationBuilderForm({ } const submitLabel = mode === "edit" ? "Save changes" : "Create automation"; + // Only gate creation; editing an existing automation isn't blocked here. + const submitBlocked = mode === "create" && !!submitDisabledReason; return (
@@ -390,15 +400,39 @@ export function AutomationBuilderForm({ - + {submitBlocked ? ( + + + {/* aria-disabled keeps the button focusable so the tooltip is + reachable by hover and keyboard; onClick is a no-op. */} + + + {submitDisabledReason} + + ) : ( + + )}
); diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/new/automation-new-content.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/new/automation-new-content.tsx index 0c983aaf8..a40b0a31b 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/new/automation-new-content.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/new/automation-new-content.tsx @@ -1,5 +1,7 @@ "use client"; import { ShieldAlert } from "lucide-react"; +import { useAutomationModelEligibility } from "@/hooks/use-automation-model-eligibility"; +import { AutomationModelGateAlert } from "../components/automation-model-gate-alert"; import { AutomationBuilderForm } from "../components/builder/automation-builder-form"; import { useAutomationPermissions } from "../hooks/use-automation-permissions"; import { AutomationNewHeader } from "./components/automation-new-header"; @@ -16,6 +18,9 @@ interface AutomationNewContentProps { */ export function AutomationNewContent({ searchSpaceId }: AutomationNewContentProps) { const perms = useAutomationPermissions(); + const { data: eligibility, isLoading: eligibilityLoading } = useAutomationModelEligibility( + perms.canCreate ? searchSpaceId : undefined + ); if (perms.loading) { return
; @@ -33,10 +38,22 @@ export function AutomationNewContent({ searchSpaceId }: AutomationNewContentProp ); } + const modelViolations = eligibility?.violations ?? []; + const submitDisabledReason = eligibilityLoading + ? "Checking model eligibility…" + : modelViolations[0]?.reason; + return ( <> - + {modelViolations.length > 0 && ( + + )} + ); } diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts index f4577a7a9..476d89d4c 100644 --- a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts +++ b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts @@ -118,6 +118,12 @@ export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => { cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), (old: Record | undefined) => ({ ...old, ...request.data }) ); + // Automation eligibility is derived from these model preferences + // (agent/image/vision). Invalidate it so the automations gate alert + // reflects the new selection without a manual refresh. + queryClient.invalidateQueries({ + queryKey: cacheKeys.automations.modelEligibility(Number(searchSpaceId)), + }); }, onError: (error: Error) => { toast.error(error.message || "Failed to update LLM preferences"); diff --git a/surfsense_web/contracts/types/automation.types.ts b/surfsense_web/contracts/types/automation.types.ts index a93249735..a1f2bd382 100644 --- a/surfsense_web/contracts/types/automation.types.ts +++ b/surfsense_web/contracts/types/automation.types.ts @@ -60,6 +60,15 @@ export const inputs = z.object({ }); export type Inputs = z.infer; +// Captured model snapshot (server-managed). Set at create time and preserved +// across edits so runs are insulated from later chat/search-space model changes. +export const automationModels = z.object({ + agent_llm_id: z.number().int().default(0), + image_generation_config_id: z.number().int().default(0), + vision_llm_config_id: z.number().int().default(0), +}); +export type AutomationModels = z.infer; + export const automationDefinition = z.object({ schema_version: z.string().default("1.0"), name: z.string().min(1).max(200), @@ -69,6 +78,7 @@ export const automationDefinition = z.object({ plan: z.array(planStep).min(1), execution: execution.default(execution.parse({})), metadata: metadata.default(metadata.parse({})), + models: automationModels.nullable().optional(), }); export type AutomationDefinition = z.infer; @@ -191,3 +201,23 @@ export const runListParams = z.object({ offset: z.number().int().min(0).default(0), }); export type RunListParams = z.infer; + +// ============================================================================= +// Model eligibility — mirror app/automations/api/automation.py (ModelEligibility) +// ============================================================================= + +export const modelEligibilityKind = z.enum(["llm", "image", "vision"]); +export type ModelEligibilityKind = z.infer; + +export const modelEligibilityViolation = z.object({ + kind: modelEligibilityKind, + config_id: z.number().nullable(), + reason: z.string(), +}); +export type ModelEligibilityViolation = z.infer; + +export const modelEligibility = z.object({ + allowed: z.boolean(), + violations: z.array(modelEligibilityViolation), +}); +export type ModelEligibility = z.infer; diff --git a/surfsense_web/hooks/use-automation-model-eligibility.ts b/surfsense_web/hooks/use-automation-model-eligibility.ts new file mode 100644 index 000000000..c9bb8b1ea --- /dev/null +++ b/surfsense_web/hooks/use-automation-model-eligibility.ts @@ -0,0 +1,25 @@ +"use client"; +import { useQuery } from "@tanstack/react-query"; +import type { ModelEligibility } from "@/contracts/types/automation.types"; +import { automationsApiService } from "@/lib/apis/automations-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; + +/** + * Whether the search space's configured models are billable for automations. + * + * Automations may only run on premium global models or user-provided (BYOK) + * models; free global models and Auto mode are blocked so every run is metered + * in premium credits. Creation surfaces use this to gate their CTAs before the + * user invests effort drafting an automation that can't be saved. + * + * Keyed by search space id (not the jotai "current scope" atom) so it can be + * used on the create route as well as the list page. + */ +export function useAutomationModelEligibility(searchSpaceId: number | undefined) { + return useQuery({ + queryKey: cacheKeys.automations.modelEligibility(searchSpaceId ?? 0), + queryFn: () => automationsApiService.getModelEligibility(searchSpaceId as number), + enabled: !!searchSpaceId, + staleTime: 60_000, + }); +} diff --git a/surfsense_web/lib/apis/automations-api.service.ts b/surfsense_web/lib/apis/automations-api.service.ts index ebe72bea5..baaf08799 100644 --- a/surfsense_web/lib/apis/automations-api.service.ts +++ b/surfsense_web/lib/apis/automations-api.service.ts @@ -6,6 +6,7 @@ import { automationCreateRequest, automationListResponse, automationUpdateRequest, + modelEligibility, type RunListParams, run, runListResponse, @@ -62,6 +63,13 @@ class AutomationsApiService { return baseApiService.delete(`${BASE}/${automationId}`); }; + // Whether the search space's models are billable for automations (premium + // global or BYOK). Used to gate creation surfaces before submit. + getModelEligibility = async (searchSpaceId: number) => { + const qs = new URLSearchParams({ search_space_id: String(searchSpaceId) }); + return baseApiService.get(`${BASE}/model-eligibility?${qs.toString()}`, modelEligibility); + }; + // ---- Triggers (sub-resource) -------------------------------------------- addTrigger = async (automationId: number, request: TriggerCreateRequest) => { diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 35724cf94..49bcd8d0e 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -134,5 +134,7 @@ export const cacheKeys = { ["automations", "runs", automationId, limit, offset] as const, run: (automationId: number, runId: number) => ["automations", "runs", automationId, runId] as const, + modelEligibility: (searchSpaceId: number) => + ["automations", "model-eligibility", searchSpaceId] as const, }, };