From bd4a04f2e72a9aebb7a5f614f83bbd43d97487f1 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:45:43 +0530 Subject: [PATCH] feat(database-migrations): add migration to remove legacy model config tables and remove stale model connection code --- ...38_add_thread_auto_model_pinning_fields.py | 2 +- .../161_remove_legacy_model_configs.py | 270 +++ .../deliverables/tools/generate_image.py | 2 +- .../app/agents/chat/runtime/llm_config.py | 132 +- .../app/automations/services/model_policy.py | 12 +- surfsense_backend/app/config/__init__.py | 50 +- .../app/config/global_llm_config.example.yaml | 80 +- surfsense_backend/app/db.py | 287 +-- .../prompts/default_system_instructions.py | 4 +- .../system_prompt_composer/composer.py | 3 +- surfsense_backend/app/routes/__init__.py | 4 - .../app/routes/image_generation_routes.py | 291 +-- .../app/routes/model_connections_routes.py | 17 +- .../app/routes/new_llm_config_routes.py | 480 ----- .../app/routes/search_spaces_routes.py | 360 +--- .../app/routes/vision_llm_routes.py | 304 ---- surfsense_backend/app/schemas/__init__.py | 45 - .../app/schemas/image_generation.py | 165 +- .../app/schemas/new_llm_config.py | 256 --- surfsense_backend/app/schemas/vision_llm.py | 116 -- .../app/services/auto_model_pin_service.py | 20 +- .../app/services/billable_calls.py | 57 +- surfsense_backend/app/services/llm_service.py | 24 +- .../app/services/model_list_service.py | 2 +- .../openrouter_integration_service.py | 93 +- .../app/services/pricing_registration.py | 6 +- .../app/services/quality_score.py | 2 +- .../app/services/vision_llm_router_service.py | 160 -- .../app/services/vision_model_list_service.py | 134 -- .../scripts/verify_chat_image_capability.py | 43 +- .../builtin/agent_task/test_dependencies.py | 40 +- .../runtime/test_executor_action_ctx.py | 18 +- .../schemas/definition/test_envelope.py | 18 +- .../test_automation_service_policy.py | 120 +- .../automations/services/test_model_policy.py | 62 +- .../routes/test_byok_supports_image_input.py | 110 -- .../routes/test_global_configs_is_premium.py | 184 -- ...t_global_new_llm_configs_supports_image.py | 106 -- .../tests/unit/routes/test_image_gen_quota.py | 62 +- .../services/test_agent_billing_resolver.py | 232 +-- .../services/test_auto_model_pin_service.py | 135 +- .../test_image_gen_api_base_defense.py | 54 +- .../test_openrouter_integration_service.py | 93 +- .../services/test_pricing_registration.py | 74 - .../tests/unit/services/test_quality_score.py | 2 +- .../test_vision_llm_api_base_defense.py | 77 - surfsense_evals/README.md | 4 +- .../parser_compare/run_artifact.json | 2 +- .../src/surfsense_evals/core/cli.py | 73 +- .../core/clients/search_space.py | 116 +- .../src/surfsense_evals/core/config.py | 21 +- .../src/surfsense_evals/core/registry.py | 4 +- .../src/surfsense_evals/core/vision_llm.py | 4 +- .../suites/medical/medxpertqa/runner.py | 2 +- .../multimodal_doc/mmlongbench/runner.py | 2 +- .../multimodal_doc/parser_compare/runner.py | 2 +- .../suites/research/crag/runner.py | 2 +- .../suites/research/frames/runner.py | 2 +- surfsense_evals/tests/core/test_clients.py | 23 +- surfsense_evals/tests/core/test_config.py | 30 +- .../tests/test_integration_smoke.py | 2 +- .../image-models/page.tsx | 6 - .../search-space-settings/roles/page.tsx | 6 - .../vision-models/page.tsx | 6 - .../image-gen-config-mutation.atoms.ts | 96 - .../image-gen-config-query.atoms.ts | 33 - .../new-llm-config-mutation.atoms.ts | 132 -- .../new-llm-config-query.atoms.ts | 98 -- .../vision-llm-config-mutation.atoms.ts | 87 - .../vision-llm-config-query.atoms.ts | 51 - .../components/new-chat/chat-header.tsx | 153 +- .../settings/agent-model-manager.tsx | 423 ----- .../settings/image-model-manager.tsx | 489 ------ .../components/settings/llm-role-manager.tsx | 443 ----- .../settings/vision-model-manager.tsx | 486 ----- .../components/shared/image-config-dialog.tsx | 456 ----- .../components/shared/llm-config-form.tsx | 527 ------ .../components/shared/model-config-dialog.tsx | 339 ---- .../shared/vision-config-dialog.tsx | 478 ----- .../contracts/enums/image-gen-providers.ts | 105 -- surfsense_web/contracts/enums/llm-models.ts | 1558 ----------------- .../contracts/enums/llm-providers.ts | 197 --- .../contracts/enums/vision-providers.ts | 168 -- .../contracts/types/new-llm-config.types.ts | 476 ----- .../lib/apis/image-gen-config-api.service.ts | 81 - .../lib/apis/new-llm-config-api.service.ts | 178 -- .../lib/apis/vision-llm-config-api.service.ts | 63 - surfsense_web/lib/query-client/cache-keys.ts | 19 - surfsense_web/messages/en.json | 8 - surfsense_web/messages/es.json | 31 +- surfsense_web/messages/hi.json | 31 +- surfsense_web/messages/pt.json | 31 +- surfsense_web/messages/zh.json | 46 +- 93 files changed, 956 insertions(+), 11442 deletions(-) create mode 100644 surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py delete mode 100644 surfsense_backend/app/routes/new_llm_config_routes.py delete mode 100644 surfsense_backend/app/routes/vision_llm_routes.py delete mode 100644 surfsense_backend/app/schemas/new_llm_config.py delete mode 100644 surfsense_backend/app/schemas/vision_llm.py delete mode 100644 surfsense_backend/app/services/vision_llm_router_service.py delete mode 100644 surfsense_backend/app/services/vision_model_list_service.py delete mode 100644 surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py delete mode 100644 surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py delete mode 100644 surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py delete mode 100644 surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx delete mode 100644 surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts delete mode 100644 surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts delete mode 100644 surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts delete mode 100644 surfsense_web/components/settings/agent-model-manager.tsx delete mode 100644 surfsense_web/components/settings/image-model-manager.tsx delete mode 100644 surfsense_web/components/settings/llm-role-manager.tsx delete mode 100644 surfsense_web/components/settings/vision-model-manager.tsx delete mode 100644 surfsense_web/components/shared/image-config-dialog.tsx delete mode 100644 surfsense_web/components/shared/llm-config-form.tsx delete mode 100644 surfsense_web/components/shared/model-config-dialog.tsx delete mode 100644 surfsense_web/components/shared/vision-config-dialog.tsx delete mode 100644 surfsense_web/contracts/enums/image-gen-providers.ts delete mode 100644 surfsense_web/contracts/enums/llm-models.ts delete mode 100644 surfsense_web/contracts/enums/llm-providers.ts delete mode 100644 surfsense_web/contracts/enums/vision-providers.ts delete mode 100644 surfsense_web/contracts/types/new-llm-config.types.ts delete mode 100644 surfsense_web/lib/apis/image-gen-config-api.service.ts delete mode 100644 surfsense_web/lib/apis/new-llm-config-api.service.ts delete mode 100644 surfsense_web/lib/apis/vision-llm-config-api.service.ts diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index fba621a0c..8c74b637b 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -4,7 +4,7 @@ Revision ID: 138 Revises: 137 Create Date: 2026-04-30 -Add a single thread-level column to persist the Auto (Fastest) model pin: +Add a single thread-level column to persist the Auto model pin: - pinned_llm_config_id: concrete resolved global LLM config id used for this thread. NULL means "no pin; Auto will resolve on next turn". diff --git a/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py new file mode 100644 index 000000000..2108d763c --- /dev/null +++ b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py @@ -0,0 +1,270 @@ +"""remove legacy model config tables + +Revision ID: 161 +Revises: 160 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.types import TypeEngine + +from alembic import op + +revision: str = "161" +down_revision: str | None = "160" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +litellm_provider = postgresql.ENUM( + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "BEDROCK", + "VERTEX_AI", + "GROQ", + "COHERE", + "MISTRAL", + "DEEPSEEK", + "XAI", + "OPENROUTER", + "TOGETHER_AI", + "FIREWORKS_AI", + "REPLICATE", + "PERPLEXITY", + "OLLAMA", + "ALIBABA_QWEN", + "MOONSHOT", + "ZHIPU", + "ANYSCALE", + "DEEPINFRA", + "CEREBRAS", + "SAMBANOVA", + "AI21", + "CLOUDFLARE", + "DATABRICKS", + "COMETAPI", + "HUGGINGFACE", + "GITHUB_MODELS", + "MINIMAX", + "CUSTOM", + name="litellmprovider", + create_type=False, +) +image_gen_provider = postgresql.ENUM( + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", + name="imagegenprovider", + create_type=False, +) +vision_provider = postgresql.ENUM( + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "VERTEX_AI", + "BEDROCK", + "XAI", + "OPENROUTER", + "OLLAMA", + "GROQ", + "TOGETHER_AI", + "FIREWORKS_AI", + "DEEPSEEK", + "MISTRAL", + "CUSTOM", + name="visionprovider", + create_type=False, +) + + +def _table_exists(table_name: str) -> bool: + return table_name in sa.inspect(op.get_bind()).get_table_names() + + +def _column_exists(table_name: str, column_name: str) -> bool: + if not _table_exists(table_name): + return False + return column_name in { + column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) + } + + +def _drop_column_if_exists(table_name: str, column_name: str) -> None: + if _column_exists(table_name, column_name): + op.drop_column(table_name, column_name) + + +def _rename_column_if_exists( + table_name: str, + old_column_name: str, + new_column_name: str, + *, + existing_type: TypeEngine, + existing_nullable: bool = True, +) -> None: + if _column_exists(table_name, old_column_name) and not _column_exists( + table_name, new_column_name + ): + op.alter_column( + table_name, + old_column_name, + new_column_name=new_column_name, + existing_type=existing_type, + existing_nullable=existing_nullable, + ) + + +def upgrade() -> None: + for table_name in ( + "new_llm_configs", + "vision_llm_configs", + "image_generation_configs", + ): + if _table_exists(table_name): + op.drop_table(table_name) + + _drop_column_if_exists("searchspaces", "agent_llm_id") + _drop_column_if_exists("searchspaces", "image_generation_config_id") + _drop_column_if_exists("searchspaces", "vision_llm_config_id") + + _rename_column_if_exists( + "image_generations", + "image_generation_config_id", + "image_gen_model_id", + existing_type=sa.Integer(), + ) + + op.execute("DROP TYPE IF EXISTS litellmprovider") + op.execute("DROP TYPE IF EXISTS imagegenprovider") + op.execute("DROP TYPE IF EXISTS visionprovider") + + +def downgrade() -> None: + bind = op.get_bind() + litellm_provider.create(bind, checkfirst=True) + image_gen_provider.create(bind, checkfirst=True) + vision_provider.create(bind, checkfirst=True) + + _rename_column_if_exists( + "image_generations", + "image_gen_model_id", + "image_generation_config_id", + existing_type=sa.Integer(), + ) + + if _table_exists("searchspaces"): + if not _column_exists("searchspaces", "agent_llm_id"): + op.add_column( + "searchspaces", + sa.Column("agent_llm_id", sa.Integer(), nullable=True), + ) + if not _column_exists("searchspaces", "image_generation_config_id"): + op.add_column( + "searchspaces", + sa.Column("image_generation_config_id", sa.Integer(), nullable=True), + ) + if not _column_exists("searchspaces", "vision_llm_config_id"): + op.add_column( + "searchspaces", + sa.Column("vision_llm_config_id", sa.Integer(), nullable=True), + ) + + if not _table_exists("image_generation_configs"): + op.create_table( + "image_generation_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", image_gen_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("api_version", sa.String(length=50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_image_generation_configs_name"), + "image_generation_configs", + ["name"], + unique=False, + ) + + if not _table_exists("vision_llm_configs"): + op.create_table( + "vision_llm_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", vision_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("api_version", sa.String(length=50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_vision_llm_configs_name"), + "vision_llm_configs", + ["name"], + unique=False, + ) + + if not _table_exists("new_llm_configs"): + op.create_table( + "new_llm_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", litellm_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("system_instructions", sa.Text(), nullable=False), + sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False), + sa.Column("citations_enabled", sa.Boolean(), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_new_llm_configs_name"), + "new_llm_configs", + ["name"], + unique=False, + ) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 505831faa..d847e021a 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -215,7 +215,7 @@ def create_generate_image_tool( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), n=n, - image_generation_config_id=config_id, + image_gen_model_id=config_id, response_data=response_dict, search_space_id=search_space_id, access_token=access_token, diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index efc188df8..e00d16ee8 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -24,8 +24,6 @@ from langchain_core.messages import AIMessage, BaseMessage from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_litellm import ChatLiteLLM from litellm import get_model_info -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.runtime.prompt_caching import ( apply_litellm_prompt_caching, @@ -34,7 +32,6 @@ from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, _sanitize_content, - is_auto_mode, ) @@ -130,7 +127,7 @@ class AgentConfig: """ Complete configuration for the SurfSense agent. - This combines LLM settings with prompt configuration from NewLLMConfig. + This combines resolved model settings with prompt configuration. Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to a concrete global or BYOK model before constructing ChatLiteLLM. """ @@ -180,7 +177,7 @@ class AgentConfig: use_default_system_instructions=True, citations_enabled=True, config_id=AUTO_MODE_ID, - config_name="Auto (Fastest)", + config_name="Auto", is_auto_mode=True, billing_tier="free", is_premium=False, @@ -191,57 +188,12 @@ class AgentConfig: supports_image_input=True, ) - @classmethod - def from_new_llm_config(cls, config) -> "AgentConfig": - """Build an AgentConfig from a NewLLMConfig database model.""" - # Lazy import: keeps provider_capabilities (and litellm) out of init order. - from app.services.provider_capabilities import derive_supports_image_input - - provider_value = ( - config.provider.value - if hasattr(config.provider, "value") - else str(config.provider) - ) - litellm_params = config.litellm_params or {} - base_model = ( - litellm_params.get("base_model") - if isinstance(litellm_params, dict) - else None - ) - - return cls( - provider=provider_value, - model_name=config.model_name, - api_key=config.api_key, - api_base=config.api_base, - custom_provider=config.custom_provider, - litellm_params=config.litellm_params, - system_instructions=config.system_instructions, - use_default_system_instructions=config.use_default_system_instructions, - citations_enabled=config.citations_enabled, - config_id=config.id, - config_name=config.name, - is_auto_mode=False, - billing_tier="free", - is_premium=False, - anonymous_enabled=False, - quota_reserve_tokens=None, - # BYOK rows have no curated flag; ask LiteLLM (default-allow on - # unknown). The streaming safety net still blocks explicit text-only. - supports_image_input=derive_supports_image_input( - provider=provider_value.lower(), - model_name=config.model_name, - base_model=base_model, - custom_provider=config.custom_provider, - ), - ) - @classmethod def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": """Build an AgentConfig from a YAML configuration dictionary. - Supports the same prompt fields as NewLLMConfig (system_instructions, - use_default_system_instructions, citations_enabled). + Supports prompt fields such as system_instructions, + use_default_system_instructions, and citations_enabled. """ # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input @@ -334,82 +286,6 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: return load_llm_config_from_yaml(llm_config_id) -async def load_new_llm_config_from_db( - session: AsyncSession, - config_id: int, -) -> "AgentConfig | None": - """Load a NewLLMConfig from the database by ID.""" - from app.db import NewLLMConfig - - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - print(f"Error: NewLLMConfig with id {config_id} not found") - return None - - return AgentConfig.from_new_llm_config(config) - except Exception as e: - print(f"Error loading NewLLMConfig from database: {e}") - return None - - -async def load_agent_llm_config_for_search_space( - session: AsyncSession, - search_space_id: int, -) -> "AgentConfig | None": - """Load the chat model config for a search space via its agent_llm_id. - - Positive id -> DB; negative -> YAML; None -> first global config (-1). - """ - from app.db import SearchSpace - - try: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - print(f"Error: SearchSpace with id {search_space_id} not found") - return None - - config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 - ) - return await load_agent_config(session, config_id, search_space_id) - except Exception as e: - print(f"Error loading chat model config for search space {search_space_id}: {e}") - return None - - -async def load_agent_config( - session: AsyncSession, - config_id: int, - search_space_id: int | None = None, -) -> "AgentConfig | None": - """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" - if is_auto_mode(config_id): - return AgentConfig.from_auto_mode() - - if config_id < 0: - # In-memory covers static YAML + dynamic OpenRouter configs. - from app.config import config as app_config - - for cfg in app_config.GLOBAL_LLM_CONFIGS: - if cfg.get("id") == config_id: - return AgentConfig.from_yaml_config(cfg) - yaml_config = load_llm_config_from_yaml(config_id) - if yaml_config: - return AgentConfig.from_yaml_config(yaml_config) - return None - else: - return await load_new_llm_config_from_db(session, config_id) - - def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: """Create a ChatLiteLLM instance from a global LLM config dictionary.""" if llm_config.get("custom_provider"): diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index b160fc78d..e18264246 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -2,11 +2,11 @@ Automations run unattended, so every run must be **billable**: it may only use either a premium global model (``billing_tier == "premium"``) or a user-provided -BYOK model (a positive config id pointing at a per-user/per-space DB row). Free +BYOK model (a positive model id pointing at a per-user/per-space DB row). Free global models and Auto mode are blocked, because Auto can dispatch to a free deployment and free models aren't metered in premium credits. -Config id conventions (shared across chat / image / vision): +Model id conventions (shared across chat / image / vision): - ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` / ``VISION_AUTO_MODE_ID``). Blocked. - ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium. @@ -82,7 +82,7 @@ def get_model_eligibility( The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is - ``{"kind", "config_id", "reason"}``. + ``{"kind", "model_id", "reason"}``. """ checks: list[tuple[ModelKind, int | None]] = [ ("chat", chat_model_id), @@ -91,10 +91,10 @@ def get_model_eligibility( ] violations: list[dict] = [] - for kind, config_id in checks: - allowed, reason = _classify(kind, config_id) + for kind, model_id in checks: + allowed, reason = _classify(kind, model_id) if not allowed: - violations.append({"kind": kind, "model_id": config_id, "reason": reason}) + violations.append({"kind": kind, "model_id": model_id, "reason": reason}) return {"allowed": not violations, "violations": violations} diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index d690c1d7d..8c9662aa8 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -119,7 +119,7 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) - # Stamp Auto (Fastest) ranking metadata. YAML configs are always + # Stamp Auto ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg # whose provider == "openrouter" via _enrich_health. @@ -210,42 +210,6 @@ def load_global_image_gen_configs(): return [] -def load_global_vision_llm_configs(): - data = _global_config_data() - if not data: - return [] - - try: - configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or []) - for cfg in configs: - if isinstance(cfg, dict): - cfg.setdefault("billing_tier", "free") - return configs - except Exception as e: - print(f"Warning: Failed to load global vision LLM configs: {e}") - return [] - - -def load_vision_llm_router_settings(): - default_settings = { - "routing_strategy": "usage-based-routing", - "num_retries": 3, - "allowed_fails": 3, - "cooldown_time": 60, - } - - data = _global_config_data() - if not data: - return default_settings - - try: - settings = data.get("vision_llm_router_settings", {}) - return {**default_settings, **settings} - except Exception as e: - print(f"Warning: Failed to load vision LLM router settings: {e}") - return default_settings - - def load_image_gen_router_settings(): """ Load router settings for image generation Auto mode from YAML file. @@ -482,12 +446,6 @@ def initialize_image_gen_router(): print(f"Warning: Failed to initialize Image Generation Router: {e}") -def initialize_vision_llm_router(): - # Retired: vision Auto now uses shared capability-filtered model selection - # over GLOBAL/BYOK chat models with supports_image_input=true. - return - - class Config: # Check if ffmpeg is installed if not is_ffmpeg_installed(): @@ -869,12 +827,6 @@ class Config: # Router settings for Image Generation Auto mode IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings() - # Global Vision LLM Configurations (optional) - GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs() - - # Router settings for Vision LLM Auto mode - VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings() - # Virtual GLOBAL connection/model catalog. This is server-only metadata # derived from global_llm_config.yaml; GLOBAL keys are not stored in DB. from app.services.global_model_catalog import ( diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 06676511f..c5b65fee0 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -433,87 +433,11 @@ global_image_generation_configs: # rpm: 30 # litellm_params: {} -# ============================================================================= -# Vision LLM Configuration -# ============================================================================= -# These configurations power the vision autocomplete feature (screenshot analysis). -# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3). -# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock, -# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom -# -# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs. - -# Router Settings for Vision LLM Auto Mode -vision_llm_router_settings: - routing_strategy: "usage-based-routing" - num_retries: 3 - allowed_fails: 3 - cooldown_time: 60 - -global_vision_llm_configs: - # Example: OpenAI GPT-4o (recommended for vision) - - id: -1001 - name: "Global GPT-4o Vision" - description: "OpenAI's GPT-4o with strong vision capabilities" - litellm_provider: "openai" - model_name: "gpt-4o" - api_key: "sk-your-openai-api-key-here" - api_base: "https://api.openai.com/v1" - rpm: 500 - tpm: 100000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Google Gemini 2.0 Flash - - id: -1002 - name: "Global Gemini 2.0 Flash" - description: "Google's fast vision model with large context" - litellm_provider: "gemini" - model_name: "gemini-2.0-flash" - api_key: "your-google-ai-api-key-here" - api_base: "https://generativelanguage.googleapis.com/v1beta" - rpm: 1000 - tpm: 200000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Anthropic Claude 3.5 Sonnet - - id: -1003 - name: "Global Claude 3.5 Sonnet Vision" - description: "Anthropic's Claude 3.5 Sonnet with vision support" - litellm_provider: "anthropic" - model_name: "claude-3-5-sonnet-20241022" - api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "https://api.anthropic.com/v1" - rpm: 1000 - tpm: 100000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Azure OpenAI GPT-4o - # - id: -1004 - # name: "Global Azure GPT-4o Vision" - # description: "Azure-hosted GPT-4o for vision analysis" - # litellm_provider: "azure" - # model_name: "azure/gpt-4o-deployment" - # api_key: "your-azure-api-key-here" - # api_base: "https://your-resource.openai.azure.com" - # api_version: "2024-02-15-preview" - # rpm: 500 - # tpm: 100000 - # litellm_params: - # temperature: 0.3 - # max_tokens: 1000 - # base_model: "gpt-4o" - # Notes: # - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing # - Use negative IDs to distinguish global models from BYOK/local DB models -# - IDs must be unique across chat, vision, and image generation configs -# - Suggested static ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999 +# - IDs must be unique across chat and image generation configs +# - Suggested static ranges: chat -1..-999, image -2001..-2999 # - The 'api_key' field will not be exposed to users via API # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 728031fa0..38d0ffe33 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -198,81 +198,6 @@ class DocumentStatus: return None -class LiteLLMProvider(StrEnum): - """ - Enum for LLM providers supported by LiteLLM. - """ - - OPENAI = "OPENAI" - ANTHROPIC = "ANTHROPIC" - GOOGLE = "GOOGLE" - AZURE_OPENAI = "AZURE_OPENAI" - BEDROCK = "BEDROCK" - VERTEX_AI = "VERTEX_AI" - GROQ = "GROQ" - COHERE = "COHERE" - MISTRAL = "MISTRAL" - DEEPSEEK = "DEEPSEEK" - XAI = "XAI" - OPENROUTER = "OPENROUTER" - TOGETHER_AI = "TOGETHER_AI" - FIREWORKS_AI = "FIREWORKS_AI" - REPLICATE = "REPLICATE" - PERPLEXITY = "PERPLEXITY" - OLLAMA = "OLLAMA" - ALIBABA_QWEN = "ALIBABA_QWEN" - MOONSHOT = "MOONSHOT" - ZHIPU = "ZHIPU" - ANYSCALE = "ANYSCALE" - DEEPINFRA = "DEEPINFRA" - CEREBRAS = "CEREBRAS" - SAMBANOVA = "SAMBANOVA" - AI21 = "AI21" - CLOUDFLARE = "CLOUDFLARE" - DATABRICKS = "DATABRICKS" - COMETAPI = "COMETAPI" - HUGGINGFACE = "HUGGINGFACE" - GITHUB_MODELS = "GITHUB_MODELS" - MINIMAX = "MINIMAX" - CUSTOM = "CUSTOM" - - -class ImageGenProvider(StrEnum): - """ - Enum for image generation providers supported by LiteLLM. - This is a subset of LLM providers — only those that support image generation. - See: https://docs.litellm.ai/docs/image_generation#supported-providers - """ - - OPENAI = "OPENAI" - AZURE_OPENAI = "AZURE_OPENAI" - GOOGLE = "GOOGLE" # Google AI Studio - VERTEX_AI = "VERTEX_AI" - BEDROCK = "BEDROCK" # AWS Bedrock - RECRAFT = "RECRAFT" - OPENROUTER = "OPENROUTER" - XINFERENCE = "XINFERENCE" - NSCALE = "NSCALE" - - -class VisionProvider(StrEnum): - OPENAI = "OPENAI" - ANTHROPIC = "ANTHROPIC" - GOOGLE = "GOOGLE" - AZURE_OPENAI = "AZURE_OPENAI" - VERTEX_AI = "VERTEX_AI" - BEDROCK = "BEDROCK" - XAI = "XAI" - OPENROUTER = "OPENROUTER" - OLLAMA = "OLLAMA" - GROQ = "GROQ" - TOGETHER_AI = "TOGETHER_AI" - FIREWORKS_AI = "FIREWORKS_AI" - DEEPSEEK = "DEEPSEEK" - MISTRAL = "MISTRAL" - CUSTOM = "CUSTOM" - - class ConnectionScope(StrEnum): GLOBAL = "GLOBAL" SEARCH_SPACE = "SEARCH_SPACE" @@ -710,11 +635,11 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) - # Auto (Fastest) model pin for this thread: concrete resolved global LLM + # Auto model pin for this thread: concrete resolved global LLM # config id. NULL means no pin; Auto will resolve on the next turn. # Single-writer invariant: only app.services.auto_model_pin_service sets # or clears this column (plus bulk clears when a search space's - # agent_llm_id changes). Unindexed: all reads are by primary key. + # chat_model_id changes). Unindexed: all reads are by primary key. pinned_llm_config_id = Column(Integer, nullable=True) # Surface metadata for first-party SurfSense and external chat threads. @@ -1686,75 +1611,6 @@ class Model(BaseModel, TimestampMixin): ) -class ImageGenerationConfig(BaseModel, TimestampMixin): - """ - Dedicated configuration table for image generation models. - - Separate from NewLLMConfig because image generation models don't need - system_instructions, citations_enabled, or use_default_system_instructions. - They only need provider credentials and model parameters. - """ - - __tablename__ = "image_generation_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider) - provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False) - custom_provider = Column(String(100), nullable=True) - model_name = Column(String(100), nullable=False) - - # Credentials - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - api_version = Column(String(50), nullable=True) # Azure-specific - - # Additional litellm parameters - litellm_params = Column(JSON, nullable=True, default={}) - - # Relationships - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship( - "SearchSpace", back_populates="image_generation_configs" - ) - - # User who created this config - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="image_generation_configs") - - -class VisionLLMConfig(BaseModel, TimestampMixin): - __tablename__ = "vision_llm_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False) - custom_provider = Column(String(100), nullable=True) - model_name = Column(String(100), nullable=False) - - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - api_version = Column(String(50), nullable=True) - - litellm_params = Column(JSON, nullable=True, default={}) - - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="vision_llm_configs") - - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="vision_llm_configs") - - class ImageGeneration(BaseModel, TimestampMixin): """ Stores image generation requests and results using litellm.aimage_generation(). @@ -1786,10 +1642,9 @@ class ImageGeneration(BaseModel, TimestampMixin): style = Column(String(50), nullable=True) # Model-specific style parameter response_format = Column(String(50), nullable=True) # "url" or "b64_json" - # Image generation config reference - # 0 = Auto mode (router), negative IDs = global configs from YAML, - # positive IDs = ImageGenerationConfig records in DB - image_generation_config_id = Column(Integer, nullable=True) + # Image generation model provenance. + # 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records. + image_gen_model_id = Column(Integer, nullable=True) # Response data (full litellm response as JSONB) — present on success response_data = Column(JSONB, nullable=True) @@ -1831,23 +1686,7 @@ class SearchSpace(BaseModel, TimestampMixin): shared_memory_md = Column(Text, nullable=True, server_default="") - # Search space-level LLM preferences (shared by all members) - # Note: ID values: - # - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces - # - Negative IDs: Global configs from YAML - # - Positive IDs: Custom configs from DB (NewLLMConfig table) - agent_llm_id = Column( - Integer, nullable=True, default=0 - ) # For chat operations, defaults to Auto mode - image_generation_config_id = Column( - Integer, nullable=True, default=0 - ) # For image generation, defaults to Auto mode - vision_llm_config_id = Column( - Integer, nullable=True, default=0 - ) # For vision/screenshot analysis, defaults to Auto mode - - # New connection/model role bindings. These supersede the legacy config - # columns above without removing them in this PR. + # Connection/model role bindings. # Note: ID values preserve the existing convention: # - 0: Auto mode # - Negative IDs: Global virtual models from global_llm_config.yaml @@ -1931,24 +1770,6 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="SearchSourceConnector.id", cascade="all, delete-orphan", ) - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="search_space", - order_by="NewLLMConfig.id", - cascade="all, delete-orphan", - ) - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="search_space", - order_by="ImageGenerationConfig.id", - cascade="all, delete-orphan", - ) - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="search_space", - order_by="VisionLLMConfig.id", - cascade="all, delete-orphan", - ) connections = relationship( "Connection", back_populates="search_space", @@ -2057,64 +1878,6 @@ class SearchSourceConnector(BaseModel, TimestampMixin): documents = relationship("Document", back_populates="connector") -class NewLLMConfig(BaseModel, TimestampMixin): - """ - New LLM configuration table that combines model settings with prompt configuration. - - This table provides: - - LLM model configuration (provider, model_name, api_key, etc.) - - Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) - - Citation toggle (enable/disable citation instructions) - - Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory). - """ - - __tablename__ = "new_llm_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - # === LLM Model Configuration (from original LLMConfig, excluding 'language') === - # Provider from the enum - provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False) - # Custom provider name when provider is CUSTOM - custom_provider = Column(String(100), nullable=True) - # Just the model name without provider prefix - model_name = Column(String(100), nullable=False) - # API Key should be encrypted before storing - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - # For any other parameters that litellm supports - litellm_params = Column(JSON, nullable=True, default={}) - - # === Prompt Configuration === - # Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) - # Users can customize this from the UI - system_instructions = Column( - Text, - nullable=False, - default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS - ) - # Whether to use the default system instructions when system_instructions is empty - use_default_system_instructions = Column(Boolean, nullable=False, default=True) - - # Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected - # When disabled, an anti-citation prompt is injected instead - citations_enabled = Column(Boolean, nullable=False, default=True) - - # === Relationships === - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="new_llm_configs") - - # User who created this config - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="new_llm_configs") - - class Log(BaseModel, TimestampMixin): __tablename__ = "logs" @@ -2481,25 +2244,6 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) - # LLM configs created by this user - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="user", - passive_deletes=True, - ) - - # Image generation configs created by this user - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="user", - passive_deletes=True, - ) - - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="user", - passive_deletes=True, - ) connections = relationship( "Connection", back_populates="user", @@ -2632,25 +2376,6 @@ else: passive_deletes=True, ) - # LLM configs created by this user - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="user", - passive_deletes=True, - ) - - # Image generation configs created by this user - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="user", - passive_deletes=True, - ) - - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="user", - passive_deletes=True, - ) connections = relationship( "Connection", back_populates="user", diff --git a/surfsense_backend/app/prompts/default_system_instructions.py b/surfsense_backend/app/prompts/default_system_instructions.py index fd0a8e186..b968fc1f0 100644 --- a/surfsense_backend/app/prompts/default_system_instructions.py +++ b/surfsense_backend/app/prompts/default_system_instructions.py @@ -82,7 +82,7 @@ def build_configurable_system_prompt( *, model_name: str | None = None, ) -> str: - """Build a configurable SurfSense system prompt (NewLLMConfig path). + """Build a configurable SurfSense system prompt. See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt` for full parameter docs. @@ -104,7 +104,7 @@ def build_configurable_system_prompt( def get_default_system_instructions() -> str: """Return the default ```` block (no tools / citations). - Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``. + Useful for populating the UI when editing custom system instructions. The output reflects the current fragment tree, not a baked-in constant. """ resolved_today = datetime.now(UTC).date().isoformat() diff --git a/surfsense_backend/app/prompts/system_prompt_composer/composer.py b/surfsense_backend/app/prompts/system_prompt_composer/composer.py index 3849af313..c639d4aa0 100644 --- a/surfsense_backend/app/prompts/system_prompt_composer/composer.py +++ b/surfsense_backend/app/prompts/system_prompt_composer/composer.py @@ -348,8 +348,7 @@ def compose_system_prompt( mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject an explicit MCP routing block. custom_system_instructions: Free-form instructions that override - the default ```` block (legacy support - for ``NewLLMConfig.system_instructions``). + the default ```` block. use_default_system_instructions: When ``custom_system_instructions`` is empty/None, fall back to defaults (legacy semantics). citations_enabled: Include ``citations_on.md`` (true) or diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index f9f6b3d28..2b997cef5 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -47,7 +47,6 @@ from .model_connections_routes import router as model_connections_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router -from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router @@ -64,7 +63,6 @@ from .stripe_routes import router as stripe_router from .team_memory_routes import router as team_memory_router from .teams_add_connector_route import router as teams_add_connector_router from .video_presentations_routes import router as video_presentations_router -from .vision_llm_routes import router as vision_llm_router from .youtube_routes import router as youtube_router router = APIRouter() @@ -99,7 +97,6 @@ router.include_router( ) # Video presentation status and streaming router.include_router(reports_router) # Report CRUD and multi-format export router.include_router(image_generation_router) # Image generation via litellm -router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) @@ -117,7 +114,6 @@ router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) -router.include_router(new_llm_config_router) # LLM configs with prompt configuration router.include_router(model_connections_router) # Connection-centric model catalog router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 29a2b58bc..7e95d4dba 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -1,7 +1,5 @@ """ Image Generation routes: -- CRUD for ImageGenerationConfig (user-created image model configs) -- Global image gen configs endpoint (from YAML) - Image generation execution (calls litellm.aimage_generation()) - CRUD for ImageGeneration records (results) - Image serving endpoint (serves b64_json images from DB, protected by signed tokens) @@ -21,7 +19,6 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, - ImageGenerationConfig, Model, Permission, SearchSpace, @@ -30,14 +27,14 @@ from app.db import ( get_async_session, ) from app.schemas import ( - GlobalImageGenConfigRead, - ImageGenerationConfigCreate, - ImageGenerationConfigRead, - ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, ) +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, +) from app.services.billable_calls import ( DEFAULT_IMAGE_RESERVE_MICROS, QuotaInsufficientError, @@ -47,12 +44,8 @@ from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, is_image_gen_auto_mode, ) -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) -from app.services.model_resolver import to_litellm from app.services.model_capabilities import has_capability +from app.services.model_resolver import to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -131,14 +124,14 @@ async def _execute_image_generation( Call litellm.aimage_generation() with the appropriate config. Resolution order: - 1. Explicit image_generation_config_id on the request - 2. Search space's image_generation_config_id preference + 1. Explicit image_gen_model_id on the request + 2. Search space's image_gen_model_id preference 3. Falls back to Auto mode if available """ - config_id = image_gen.image_generation_config_id + config_id = image_gen.image_gen_model_id if config_id is None: config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID - image_gen.image_generation_config_id = config_id + image_gen.image_gen_model_id = config_id # Build kwargs gen_kwargs = {} @@ -163,7 +156,7 @@ async def _execute_image_generation( if not candidates: raise ValueError("No image-generation models are available for Auto mode") config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"]) - image_gen.image_generation_config_id = config_id + image_gen.image_gen_model_id = config_id if config_id < 0: global_model = _get_global_model(config_id) @@ -228,266 +221,6 @@ async def _execute_image_generation( image_gen.model = hidden["model"] -# ============================================================================= -# Global Image Generation Configs (from YAML) -# ============================================================================= - - -@router.get( - "/global-image-generation-configs", - response_model=list[GlobalImageGenConfigRead], -) -async def get_global_image_gen_configs( - user: User = Depends(current_active_user), -): - """Get all global image generation configs. API keys are hidden.""" - try: - global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS - safe_configs = [] - - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes across available image generation providers.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "api_version": None, - "litellm_params": {}, - "is_global": True, - "is_auto_mode": True, - # Auto mode currently treated as free until per-deployment - # billing-tier surfacing lands (see _resolve_billing_for_image_gen). - "billing_tier": "free", - "is_premium": False, - } - ) - - for cfg in global_configs: - billing_tier = str(cfg.get("billing_tier", "free")).lower() - safe_configs.append( - { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": billing_tier, - # Mirror chat (``new_llm_config_routes``) so the new-chat - # selector's premium badge logic keys off the same - # field across chat / image / vision tabs. - "is_premium": billing_tier == "premium", - "quota_reserve_micros": cfg.get("quota_reserve_micros"), - } - ) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global image generation configs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -# ============================================================================= -# ImageGenerationConfig CRUD -# ============================================================================= - - -@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead) -async def create_image_gen_config( - config_data: ImageGenerationConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Create a new image generation config for a search space.""" - try: - await check_permission( - session, - user, - config_data.search_space_id, - Permission.IMAGE_GENERATIONS_CREATE.value, - "You don't have permission to create image generation configs in this search space", - ) - - db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create config: {e!s}" - ) from e - - -@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead]) -async def list_image_gen_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """List image generation configs for a search space.""" - try: - await check_permission( - session, - user, - search_space_id, - Permission.IMAGE_GENERATIONS_READ.value, - "You don't have permission to view image generation configs in this search space", - ) - - result = await session.execute( - select(ImageGenerationConfig) - .filter(ImageGenerationConfig.search_space_id == search_space_id) - .order_by(ImageGenerationConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list ImageGenerationConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -@router.get( - "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead -) -async def get_image_gen_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Get a specific image generation config by ID.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_READ.value, - "You don't have permission to view image generation configs in this search space", - ) - return db_config - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch config: {e!s}" - ) from e - - -@router.put( - "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead -) -async def update_image_gen_config( - config_id: int, - update_data: ImageGenerationConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Update an existing image generation config.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_CREATE.value, - "You don't have permission to update image generation configs in this search space", - ) - - for key, value in update_data.model_dump(exclude_unset=True).items(): - setattr(db_config, key, value) - - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update config: {e!s}" - ) from e - - -@router.delete("/image-generation-configs/{config_id}", response_model=dict) -async def delete_image_gen_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Delete an image generation config.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_DELETE.value, - "You don't have permission to delete image generation configs in this search space", - ) - - await session.delete(db_config) - await session.commit() - return { - "message": "Image generation config deleted successfully", - "id": config_id, - } - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete config: {e!s}" - ) from e - - # ============================================================================= # Image Generation Execution + Results CRUD # ============================================================================= @@ -536,7 +269,7 @@ async def create_image_generation( raise HTTPException(status_code=404, detail="Search space not found") billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( - session, data.image_generation_config_id, search_space + session, data.image_gen_model_id, search_space ) # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError @@ -562,7 +295,7 @@ async def create_image_generation( size=data.size, style=data.style, response_format=data.response_format, - image_generation_config_id=data.image_generation_config_id, + image_gen_model_id=data.image_gen_model_id, search_space_id=data.search_space_id, created_by_id=user.id, ) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 1fd2e1e8e..90d246c54 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -11,6 +11,7 @@ from app.db import ( ConnectionScope, Model, ModelSource, + NewChatThread, Permission, SearchSpace, User, @@ -708,12 +709,26 @@ async def update_model_roles( search_space = await _get_search_space(session, search_space_id) updates = data.model_dump(exclude_unset=True) if "chat_model_id" in updates: - search_space.chat_model_id = await _validate_role_model_id( + previous_chat_model_id = search_space.chat_model_id + next_chat_model_id = await _validate_role_model_id( session, search_space_id=search_space_id, model_id=updates["chat_model_id"], capability="chat", ) + search_space.chat_model_id = next_chat_model_id + if next_chat_model_id != previous_chat_model_id: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values(pinned_llm_config_id=None) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)", + search_space_id, + previous_chat_model_id, + next_chat_model_id, + ) if "vision_model_id" in updates: search_space.vision_model_id = await _validate_role_model_id( session, diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py deleted file mode 100644 index adba5b5ae..000000000 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ /dev/null @@ -1,480 +0,0 @@ -""" -API routes for NewLLMConfig CRUD operations. - -NewLLMConfig combines model settings with prompt configuration: -- LLM provider, model, API key, etc. -- Configurable system instructions -- Citation toggle -""" - -import logging - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.config import config -from app.db import ( - NewLLMConfig, - Permission, - User, - get_async_session, -) -from app.prompts.default_system_instructions import get_default_system_instructions -from app.schemas import ( - DefaultSystemInstructionsResponse, - GlobalNewLLMConfigRead, - NewLLMConfigCreate, - NewLLMConfigRead, - NewLLMConfigUpdate, -) -from app.services.llm_service import validate_llm_config -from app.services.provider_capabilities import derive_supports_image_input -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: - """Augment a BYOK chat config row with the derived ``supports_image_input``. - - There is no DB column for ``supports_image_input`` — the value is - resolved at the API boundary from LiteLLM's authoritative model map - (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps - the response shape consistent across list / detail / create / update - endpoints without having to remember to set the field at every call - site. - """ - provider_value = ( - config.provider.value - if hasattr(config.provider, "value") - else str(config.provider) - ) - litellm_params = config.litellm_params or {} - base_model = ( - litellm_params.get("base_model") if isinstance(litellm_params, dict) else None - ) - supports_image_input = derive_supports_image_input( - provider=provider_value.lower(), - model_name=config.model_name, - base_model=base_model, - custom_provider=config.custom_provider, - ) - # ``model_validate`` runs the Pydantic conversion using the ORM - # attribute access path enabled by ``ConfigDict(from_attributes=True)``, - # then we layer the derived field on. ``model_copy(update=...)`` keeps - # the surface immutable from the caller's perspective. - base_read = NewLLMConfigRead.model_validate(config) - return base_read.model_copy(update={"supports_image_input": supports_image_input}) - - -# ============================================================================= -# Global Configs Routes -# ============================================================================= - - -@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead]) -async def get_global_new_llm_configs( - user: User = Depends(current_active_user), -): - """ - Get all available global NewLLMConfig configurations. - These are pre-configured by the system administrator and available to all users. - API keys are not exposed through this endpoint. - - Includes: - - Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing - - Global configs (negative IDs): Individual pre-configured LLM providers - """ - try: - global_configs = config.GLOBAL_LLM_CONFIGS - safe_configs = [] - - # Only include Auto mode if there are actual global configs to route to - # Auto mode requires at least one global config with valid API key - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "litellm_params": {}, - "system_instructions": "", - "use_default_system_instructions": True, - "citations_enabled": True, - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - "is_premium": False, - "anonymous_enabled": False, - "seo_enabled": False, - "seo_slug": None, - "seo_title": None, - "seo_description": None, - "quota_reserve_tokens": None, - # Auto routes across the configured pool, which usually - # includes at least one vision-capable deployment, so - # treat Auto as image-capable. The router itself will - # still pick a vision-capable deployment for messages - # carrying image_url blocks (LiteLLM Router falls back - # on ``404`` per its ``allowed_fails`` policy). - "supports_image_input": True, - } - ) - - # Add individual global configs - for cfg in global_configs: - # Capability resolution: explicit value (YAML override or OR - # `_supports_image_input(model)` payload baked in by the - # OpenRouter integration service) wins. Fall back to the - # LiteLLM-driven helper which default-allows on unknown so - # we don't hide vision-capable models that happen to lack a - # YAML annotation. The streaming task safety net is the - # only place a False ever blocks. - if "supports_image_input" in cfg: - supports_image_input = bool(cfg.get("supports_image_input")) - else: - cfg_litellm_params = cfg.get("litellm_params") or {} - cfg_base_model = ( - cfg_litellm_params.get("base_model") - if isinstance(cfg_litellm_params, dict) - else None - ) - supports_image_input = derive_supports_image_input( - provider=cfg.get("provider") or cfg.get("litellm_provider"), - model_name=cfg.get("model_name"), - base_model=cfg_base_model, - custom_provider=cfg.get("custom_provider"), - ) - - safe_config = { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "litellm_params": cfg.get("litellm_params", {}), - # New prompt configuration fields - "system_instructions": cfg.get("system_instructions", ""), - "use_default_system_instructions": cfg.get( - "use_default_system_instructions", True - ), - "citations_enabled": cfg.get("citations_enabled", True), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - "is_premium": cfg.get("billing_tier", "free") == "premium", - "anonymous_enabled": cfg.get("anonymous_enabled", False), - "seo_enabled": cfg.get("seo_enabled", False), - "seo_slug": cfg.get("seo_slug"), - "seo_title": cfg.get("seo_title"), - "seo_description": cfg.get("seo_description"), - "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), - "supports_image_input": supports_image_input, - } - safe_configs.append(safe_config) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global NewLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch global configurations: {e!s}" - ) from e - - -# ============================================================================= -# CRUD Routes -# ============================================================================= - - -@router.post("/new-llm-configs", response_model=NewLLMConfigRead) -async def create_new_llm_config( - config_data: NewLLMConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Create a new NewLLMConfig for a search space. - Requires LLM_CONFIGS_CREATE permission. - """ - try: - # Verify user has permission - await check_permission( - session, - user, - config_data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create LLM configurations in this search space", - ) - - # Validate the LLM configuration by making a test API call - is_valid, error_message = await validate_llm_config( - provider=config_data.provider.value, - model_name=config_data.model_name, - api_key=config_data.api_key, - api_base=config_data.api_base, - custom_provider=config_data.custom_provider, - litellm_params=config_data.litellm_params, - ) - - if not is_valid: - raise HTTPException( - status_code=400, - detail=f"Invalid LLM configuration: {error_message}", - ) - - # Create the config with user association - db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - - return _serialize_byok_config(db_config) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create configuration: {e!s}" - ) from e - - -@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead]) -async def list_new_llm_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get all NewLLMConfigs for a search space. - Requires LLM_CONFIGS_READ permission. - """ - try: - # Verify user has permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM configurations in this search space", - ) - - result = await session.execute( - select(NewLLMConfig) - .filter(NewLLMConfig.search_space_id == search_space_id) - .order_by(NewLLMConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - - return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list NewLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configurations: {e!s}" - ) from e - - -@router.get( - "/new-llm-configs/default-system-instructions", - response_model=DefaultSystemInstructionsResponse, -) -async def get_default_system_instructions_endpoint( - user: User = Depends(current_active_user), -): - """ - Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template. - Useful for pre-populating the UI when creating a new configuration. - """ - return DefaultSystemInstructionsResponse( - default_system_instructions=get_default_system_instructions() - ) - - -@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) -async def get_new_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get a specific NewLLMConfig by ID. - Requires LLM_CONFIGS_READ permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM configurations in this search space", - ) - - return _serialize_byok_config(config) - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configuration: {e!s}" - ) from e - - -@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) -async def update_new_llm_config( - config_id: int, - update_data: NewLLMConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update an existing NewLLMConfig. - Requires LLM_CONFIGS_UPDATE permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update LLM configurations in this search space", - ) - - update_dict = update_data.model_dump(exclude_unset=True) - - # If updating LLM settings, validate them - if any( - key in update_dict - for key in [ - "provider", - "model_name", - "api_key", - "api_base", - "custom_provider", - "litellm_params", - ] - ): - # Build the validation config from existing + updates - validation_config = { - "provider": update_dict.get("provider", config.provider).value - if hasattr(update_dict.get("provider", config.provider), "value") - else update_dict.get("provider", config.provider.value), - "model_name": update_dict.get("model_name", config.model_name), - "api_key": update_dict.get("api_key", config.api_key), - "api_base": update_dict.get("api_base", config.api_base), - "custom_provider": update_dict.get( - "custom_provider", config.custom_provider - ), - "litellm_params": update_dict.get( - "litellm_params", config.litellm_params - ), - } - - is_valid, error_message = await validate_llm_config( - provider=validation_config["provider"], - model_name=validation_config["model_name"], - api_key=validation_config["api_key"], - api_base=validation_config["api_base"], - custom_provider=validation_config["custom_provider"], - litellm_params=validation_config["litellm_params"], - ) - - if not is_valid: - raise HTTPException( - status_code=400, - detail=f"Invalid LLM configuration: {error_message}", - ) - - # Apply updates - for key, value in update_dict.items(): - setattr(config, key, value) - - await session.commit() - await session.refresh(config) - - return _serialize_byok_config(config) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update configuration: {e!s}" - ) from e - - -@router.delete("/new-llm-configs/{config_id}", response_model=dict) -async def delete_new_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Delete a NewLLMConfig. - Requires LLM_CONFIGS_DELETE permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_DELETE.value, - "You don't have permission to delete LLM configurations in this search space", - ) - - await session.delete(config) - await session.commit() - - return {"message": "Configuration deleted successfully", "id": config_id} - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete configuration: {e!s}" - ) from e diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7c5fbf28b..592a9dd0e 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -1,27 +1,20 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, update +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.config import config from app.db import ( - ImageGenerationConfig, - NewChatThread, - NewLLMConfig, Permission, SearchSpace, SearchSpaceMembership, SearchSpaceRole, User, - VisionLLMConfig, get_async_session, get_default_roles_config, ) from app.schemas import ( - LLMPreferencesRead, - LLMPreferencesUpdate, SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate, @@ -377,357 +370,6 @@ async def delete_search_space( ) from e -# ============================================================================= -# LLM Preferences Routes -# ============================================================================= - - -async def _get_llm_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - """ - Get an LLM config by ID as a dictionary. Returns database config for positive IDs, - global config for negative IDs, Auto mode config for ID 0, or None if ID is None. - """ - if config_id is None: - return None - - # Auto mode (ID 0) - uses LiteLLM Router for load balancing - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "litellm_params": {}, - "system_instructions": "", - "use_default_system_instructions": True, - "citations_enabled": True, - "is_global": True, - "is_auto_mode": True, - } - - if config_id < 0: - # Global config - find from YAML - global_configs = config.GLOBAL_LLM_CONFIGS - for cfg in global_configs: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base"), - "litellm_params": cfg.get("litellm_params", {}), - "system_instructions": cfg.get("system_instructions", ""), - "use_default_system_instructions": cfg.get( - "use_default_system_instructions", True - ), - "citations_enabled": cfg.get("citations_enabled", True), - "is_global": True, - } - return None - else: - # Database config - convert to dict - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_key": db_config.api_key, - "api_base": db_config.api_base, - "litellm_params": db_config.litellm_params or {}, - "system_instructions": db_config.system_instructions or "", - "use_default_system_instructions": db_config.use_default_system_instructions, - "citations_enabled": db_config.citations_enabled, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -async def _get_image_gen_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - """ - Get an image generation config by ID as a dictionary. - Returns Auto mode for ID 0, global config for negative IDs, - DB ImageGenerationConfig for positive IDs, or None. - """ - if config_id is None: - return None - - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available image generation providers", - "provider": "AUTO", - "model_name": "auto", - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - } - - if config_id < 0: - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - } - return None - - # Positive ID: query ImageGenerationConfig table - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_base": db_config.api_base, - "api_version": db_config.api_version, - "litellm_params": db_config.litellm_params or {}, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -async def _get_vision_llm_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - if config_id is None: - return None - - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available vision LLM providers", - "provider": "AUTO", - "model_name": "auto", - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - } - - if config_id < 0: - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - } - return None - - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_base": db_config.api_base, - "api_version": db_config.api_version, - "litellm_params": db_config.litellm_params or {}, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -@router.get( - "/search-spaces/{search_space_id}/llm-preferences", - response_model=LLMPreferencesRead, -) -async def get_llm_preferences( - search_space_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get LLM preferences (role assignments) for a search space. - Requires LLM_CONFIGS_READ permission. - """ - try: - # Check permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM preferences", - ) - - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - - # Get full config objects for each role - agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) - image_generation_config = await _get_image_gen_config_by_id( - session, search_space.image_generation_config_id - ) - vision_llm_config = await _get_vision_llm_config_by_id( - session, search_space.vision_llm_config_id - ) - - return LLMPreferencesRead( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, - agent_llm=agent_llm, - image_generation_config=image_generation_config, - vision_llm_config=vision_llm_config, - ) - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get LLM preferences") - raise HTTPException( - status_code=500, detail=f"Failed to get LLM preferences: {e!s}" - ) from e - - -@router.put( - "/search-spaces/{search_space_id}/llm-preferences", - response_model=LLMPreferencesRead, -) -async def update_llm_preferences( - search_space_id: int, - preferences: LLMPreferencesUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update LLM preferences (role assignments) for a search space. - Requires LLM_CONFIGS_UPDATE permission. - """ - try: - # Check permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update LLM preferences", - ) - - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - - # Update preferences - update_data = preferences.model_dump(exclude_unset=True) - previous_agent_llm_id = search_space.agent_llm_id - for key, value in update_data.items(): - setattr(search_space, key, value) - - agent_llm_changed = ( - "agent_llm_id" in update_data - and update_data["agent_llm_id"] != previous_agent_llm_id - ) - if agent_llm_changed: - await session.execute( - update(NewChatThread) - .where(NewChatThread.search_space_id == search_space_id) - .values(pinned_llm_config_id=None) - ) - logger.info( - "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", - search_space_id, - previous_agent_llm_id, - update_data["agent_llm_id"], - ) - - await session.commit() - await session.refresh(search_space) - - # Get full config objects for response - agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) - image_generation_config = await _get_image_gen_config_by_id( - session, search_space.image_generation_config_id - ) - vision_llm_config = await _get_vision_llm_config_by_id( - session, search_space.vision_llm_config_id - ) - - return LLMPreferencesRead( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, - agent_llm=agent_llm, - image_generation_config=image_generation_config, - vision_llm_config=vision_llm_config, - ) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update LLM preferences") - raise HTTPException( - status_code=500, detail=f"Failed to update LLM preferences: {e!s}" - ) from e - - @router.get("/searchspaces/{search_space_id}/snapshots") async def list_search_space_snapshots( search_space_id: int, diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py deleted file mode 100644 index b93d25b9c..000000000 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ /dev/null @@ -1,304 +0,0 @@ -import logging - -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config import config -from app.db import ( - Permission, - User, - VisionLLMConfig, - get_async_session, -) -from app.schemas import ( - GlobalVisionLLMConfigRead, - VisionLLMConfigCreate, - VisionLLMConfigRead, - VisionLLMConfigUpdate, -) -from app.services.vision_model_list_service import get_vision_model_list -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -# ============================================================================= -# Vision Model Catalogue (from OpenRouter, filtered for image-input models) -# ============================================================================= - - -class VisionModelListItem(BaseModel): - value: str - label: str - provider: str - context_window: str | None = None - - -@router.get("/vision-models", response_model=list[VisionModelListItem]) -async def list_vision_models( - user: User = Depends(current_active_user), -): - """Return vision-capable models sourced from OpenRouter (filtered by image input).""" - try: - return await get_vision_model_list() - except Exception as e: - logger.exception("Failed to fetch vision model list") - raise HTTPException( - status_code=500, detail=f"Failed to fetch vision model list: {e!s}" - ) from e - - -# ============================================================================= -# Global Vision LLM Configs (from YAML) -# ============================================================================= - - -@router.get( - "/global-vision-llm-configs", - response_model=list[GlobalVisionLLMConfigRead], -) -async def get_global_vision_llm_configs( - user: User = Depends(current_active_user), -): - try: - global_configs = config.GLOBAL_VISION_LLM_CONFIGS - safe_configs = [] - - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes across available vision LLM providers.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "api_version": None, - "litellm_params": {}, - "is_global": True, - "is_auto_mode": True, - # Auto mode treated as free until per-deployment billing-tier - # surfacing lands; see ``get_vision_llm`` for parity. - "billing_tier": "free", - "is_premium": False, - } - ) - - for cfg in global_configs: - billing_tier = str(cfg.get("billing_tier", "free")).lower() - safe_configs.append( - { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": billing_tier, - # Mirror chat (``new_llm_config_routes``) so the new-chat - # selector's premium badge logic keys off the same - # field across chat / image / vision tabs. - "is_premium": billing_tier == "premium", - "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), - "input_cost_per_token": cfg.get("input_cost_per_token"), - "output_cost_per_token": cfg.get("output_cost_per_token"), - } - ) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global vision LLM configs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -# ============================================================================= -# VisionLLMConfig CRUD -# ============================================================================= - - -@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead) -async def create_vision_llm_config( - config_data: VisionLLMConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - await check_permission( - session, - user, - config_data.search_space_id, - Permission.VISION_CONFIGS_CREATE.value, - "You don't have permission to create vision LLM configs in this search space", - ) - - db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create config: {e!s}" - ) from e - - -@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead]) -async def list_vision_llm_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - await check_permission( - session, - user, - search_space_id, - Permission.VISION_CONFIGS_READ.value, - "You don't have permission to view vision LLM configs in this search space", - ) - - result = await session.execute( - select(VisionLLMConfig) - .filter(VisionLLMConfig.search_space_id == search_space_id) - .order_by(VisionLLMConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list VisionLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) -async def get_vision_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_READ.value, - "You don't have permission to view vision LLM configs in this search space", - ) - return db_config - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch config: {e!s}" - ) from e - - -@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) -async def update_vision_llm_config( - config_id: int, - update_data: VisionLLMConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_CREATE.value, - "You don't have permission to update vision LLM configs in this search space", - ) - - for key, value in update_data.model_dump(exclude_unset=True).items(): - setattr(db_config, key, value) - - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update config: {e!s}" - ) from e - - -@router.delete("/vision-llm-configs/{config_id}", response_model=dict) -async def delete_vision_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_DELETE.value, - "You don't have permission to delete vision LLM configs in this search space", - ) - - await session.delete(db_config) - await session.commit() - return { - "message": "Vision LLM config deleted successfully", - "id": config_id, - } - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete config: {e!s}" - ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 3c4fdfa83..f577397b6 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -34,11 +34,6 @@ from .folders import ( ) from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest from .image_generation import ( - GlobalImageGenConfigRead, - ImageGenerationConfigCreate, - ImageGenerationConfigPublic, - ImageGenerationConfigRead, - ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, @@ -74,16 +69,6 @@ from .new_chat import ( ThreadListItem, ThreadListResponse, ) -from .new_llm_config import ( - DefaultSystemInstructionsResponse, - GlobalNewLLMConfigRead, - LLMPreferencesRead, - LLMPreferencesUpdate, - NewLLMConfigCreate, - NewLLMConfigPublic, - NewLLMConfigRead, - NewLLMConfigUpdate, -) from .rbac_schemas import ( InviteAcceptRequest, InviteAcceptResponse, @@ -142,14 +127,6 @@ from .video_presentations import ( VideoPresentationRead, VideoPresentationUpdate, ) -from .vision_llm import ( - GlobalVisionLLMConfigRead, - VisionLLMConfigCreate, - VisionLLMConfigPublic, - VisionLLMConfigRead, - VisionLLMConfigUpdate, -) - __all__ = [ # Folder schemas "BulkDocumentMove", @@ -169,7 +146,6 @@ __all__ = [ "CreditPurchaseHistoryResponse", "CreditPurchaseRead", "CreditStripeStatusResponse", - "DefaultSystemInstructionsResponse", # Document schemas "DocumentBase", "DocumentMove", @@ -192,19 +168,10 @@ __all__ = [ "FolderRead", "FolderReorder", "FolderUpdate", - "GlobalImageGenConfigRead", - "GlobalNewLLMConfigRead", - # Vision LLM Config schemas - "GlobalVisionLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", # Base schemas "IDModel", - # Image Generation Config schemas - "ImageGenerationConfigCreate", - "ImageGenerationConfigPublic", - "ImageGenerationConfigRead", - "ImageGenerationConfigUpdate", # Image Generation schemas "ImageGenerationCreate", "ImageGenerationListRead", @@ -216,9 +183,6 @@ __all__ = [ "InviteInfoResponse", "InviteRead", "InviteUpdate", - # LLM Preferences schemas - "LLMPreferencesRead", - "LLMPreferencesUpdate", # Log schemas "LogBase", "LogCreate", @@ -255,11 +219,6 @@ __all__ = [ "NewChatThreadRead", "NewChatThreadUpdate", "NewChatThreadWithMessages", - # NewLLMConfig schemas - "NewLLMConfigCreate", - "NewLLMConfigPublic", - "NewLLMConfigRead", - "NewLLMConfigUpdate", "PagePurchaseHistoryResponse", "PagePurchaseRead", "PaginatedResponse", @@ -303,8 +262,4 @@ __all__ = [ "VideoPresentationCreate", "VideoPresentationRead", "VideoPresentationUpdate", - "VisionLLMConfigCreate", - "VisionLLMConfigPublic", - "VisionLLMConfigRead", - "VisionLLMConfigUpdate", ] diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 4262b2b3f..83671cc77 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -1,109 +1,10 @@ -""" -Pydantic schemas for Image Generation configs and generation requests. +"""Pydantic schemas for image generation requests/results.""" -ImageGenerationConfig: CRUD schemas for user-created image gen model configs. -ImageGeneration: Schemas for the actual image generation requests/results. -GlobalImageGenConfigRead: Schema for admin-configured YAML configs. -""" - -import uuid from datetime import datetime from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.db import ImageGenProvider - -# ============================================================================= -# ImageGenerationConfig CRUD Schemas -# ============================================================================= - - -class ImageGenerationConfigBase(BaseModel): - """Base schema with fields for ImageGenerationConfig.""" - - name: str = Field( - ..., max_length=100, description="User-friendly name for the config" - ) - description: str | None = Field( - None, max_length=500, description="Optional description" - ) - provider: ImageGenProvider = Field( - ..., - description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)", - ) - custom_provider: str | None = Field( - None, max_length=100, description="Custom provider name" - ) - model_name: str = Field( - ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)" - ) - api_key: str = Field(..., description="API key for the provider") - api_base: str | None = Field( - None, max_length=500, description="Optional API base URL" - ) - api_version: str | None = Field( - None, - max_length=50, - description="Azure-specific API version (e.g., '2024-02-15-preview')", - ) - litellm_params: dict[str, Any] | None = Field( - default=None, description="Additional LiteLLM parameters" - ) - - -class ImageGenerationConfigCreate(ImageGenerationConfigBase): - """Schema for creating a new ImageGenerationConfig.""" - - search_space_id: int = Field( - ..., description="Search space ID to associate the config with" - ) - - -class ImageGenerationConfigUpdate(BaseModel): - """Schema for updating an existing ImageGenerationConfig. All fields optional.""" - - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - provider: ImageGenProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = None - - -class ImageGenerationConfigRead(ImageGenerationConfigBase): - """Schema for reading an ImageGenerationConfig (includes id and timestamps).""" - - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class ImageGenerationConfigPublic(BaseModel): - """Public schema that hides the API key (for list views).""" - - id: int - name: str - description: str | None = None - provider: ImageGenProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - # ============================================================================= # ImageGeneration (request/result) Schemas # ============================================================================= @@ -136,12 +37,12 @@ class ImageGenerationCreate(BaseModel): search_space_id: int = Field( ..., description="Search space ID to associate the generation with" ) - image_generation_config_id: int | None = Field( + image_gen_model_id: int | None = Field( None, description=( - "Image generation config ID. " - "0 = Auto mode (router), negative = global YAML config, positive = DB config. " - "If not provided, uses the search space's image_generation_config_id preference." + "Image generation model ID. " + "0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. " + "If not provided, uses the search space's image_gen_model_id preference." ), ) @@ -157,7 +58,7 @@ class ImageGenerationRead(BaseModel): size: str | None = None style: str | None = None response_format: str | None = None - image_generation_config_id: int | None = None + image_gen_model_id: int | None = None response_data: dict[str, Any] | None = None error_message: str | None = None search_space_id: int @@ -204,57 +105,3 @@ class ImageGenerationListRead(BaseModel): image_count=image_count, ) - -# ============================================================================= -# Global Image Gen Config (from YAML) -# ============================================================================= - - -class GlobalImageGenConfigRead(BaseModel): - """ - Schema for reading global image generation configs from YAML. - Global configs have negative IDs. API key is hidden. - ID 0 is reserved for Auto mode (LiteLLM Router load balancing). - - The ``billing_tier`` field allows the frontend to show a Premium/Free - badge and (more importantly) tells the backend whether to debit the - user's premium credit pool when this config is used. ``"free"`` is - the default for backward compatibility — admins must explicitly opt - a global config into ``"premium"``. - """ - - id: int = Field( - ..., - description="Config ID: 0 for Auto mode, negative for global configs", - ) - name: str - description: str | None = None - provider: str - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - is_global: bool = True - is_auto_mode: bool = False - billing_tier: str = Field( - default="free", - description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", - ) - is_premium: bool = Field( - default=False, - description=( - "Convenience boolean derived server-side from " - "``billing_tier == 'premium'``. The new-chat model selector " - "keys its Free/Premium badge off this field for parity with " - "chat (`GlobalLLMConfigRead.is_premium`)." - ), - ) - quota_reserve_micros: int | None = Field( - default=None, - description=( - "Optional override for the reservation amount (in micro-USD) used when " - "this image generation is premium. Falls back to " - "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." - ), - ) diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py deleted file mode 100644 index 2f04a9e66..000000000 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Pydantic schemas for the NewLLMConfig API. - -NewLLMConfig combines model settings with prompt configuration: -- LLM provider, model, API key, etc. -- Configurable system instructions -- Citation toggle -""" - -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import LiteLLMProvider - - -class NewLLMConfigBase(BaseModel): - """Base schema with common fields for NewLLMConfig.""" - - name: str = Field( - ..., max_length=100, description="User-friendly name for the configuration" - ) - description: str | None = Field( - None, max_length=500, description="Optional description" - ) - - # Model Configuration - provider: LiteLLMProvider = Field(..., description="LiteLLM provider type") - custom_provider: str | None = Field( - None, max_length=100, description="Custom provider name when provider is CUSTOM" - ) - model_name: str = Field( - ..., max_length=100, description="Model name without provider prefix" - ) - api_key: str = Field(..., description="API key for the provider") - api_base: str | None = Field( - None, max_length=500, description="Optional API base URL" - ) - litellm_params: dict[str, Any] | None = Field( - default=None, description="Additional LiteLLM parameters" - ) - - # Prompt Configuration - system_instructions: str = Field( - default="", - description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.", - ) - use_default_system_instructions: bool = Field( - default=True, - description="Whether to use default instructions when system_instructions is empty", - ) - citations_enabled: bool = Field( - default=True, - description="Whether to include citation instructions in the system prompt", - ) - - -class NewLLMConfigCreate(NewLLMConfigBase): - """Schema for creating a new NewLLMConfig.""" - - search_space_id: int = Field( - ..., description="Search space ID to associate the config with" - ) - - -class NewLLMConfigUpdate(BaseModel): - """Schema for updating an existing NewLLMConfig. All fields are optional.""" - - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - - # Model Configuration - provider: LiteLLMProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str | None = None - use_default_system_instructions: bool | None = None - citations_enabled: bool | None = None - - -class NewLLMConfigRead(NewLLMConfigBase): - """Schema for reading a NewLLMConfig (includes id and timestamps).""" - - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - # Capability flag derived at the API boundary (no DB column). Default - # True matches the conservative-allow stance — a BYOK row that the - # route forgot to augment is not pre-judged. The streaming-task - # safety net is the only place a False actually blocks a request. - supports_image_input: bool = Field( - default=True, - description=( - "Whether the BYOK chat config can accept image inputs. Derived " - "at the route boundary from LiteLLM's authoritative model map " - "(``litellm.supports_vision``) — there is no DB column. " - "Default True is the conservative-allow stance for unknown / " - "unmapped models." - ), - ) - - model_config = ConfigDict(from_attributes=True) - - -class NewLLMConfigPublic(BaseModel): - """ - Public schema for NewLLMConfig that hides the API key. - Used when returning configs in list views or to users who shouldn't see keys. - """ - - id: int - name: str - description: str | None = None - - # Model Configuration (no api_key) - provider: LiteLLMProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str - use_default_system_instructions: bool - citations_enabled: bool - - created_at: datetime - search_space_id: int - user_id: uuid.UUID - # Capability flag derived at the API boundary (see NewLLMConfigRead). - supports_image_input: bool = Field( - default=True, - description=( - "Whether the BYOK chat config can accept image inputs. Derived " - "at the route boundary from LiteLLM's authoritative model map. " - "Default True is the conservative-allow stance." - ), - ) - - model_config = ConfigDict(from_attributes=True) - - -class DefaultSystemInstructionsResponse(BaseModel): - """Response schema for getting default system instructions.""" - - default_system_instructions: str = Field( - ..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template" - ) - - -class GlobalNewLLMConfigRead(BaseModel): - """ - Schema for reading global LLM configs from YAML. - Global configs have negative IDs and no search_space_id. - API key is hidden for security. - - ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing. - """ - - id: int = Field( - ..., - description="Config ID: 0 for Auto mode, negative for global configs", - ) - name: str - description: str | None = None - - # Model Configuration (no api_key) - provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode - custom_provider: str | None = None - model_name: str - api_base: str | None = None - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str = "" - use_default_system_instructions: bool = True - citations_enabled: bool = True - - is_global: bool = True # Always true for global configs - is_auto_mode: bool = False # True only for Auto mode (ID 0) - - billing_tier: str = "free" - is_premium: bool = False - anonymous_enabled: bool = False - seo_enabled: bool = False - seo_slug: str | None = None - seo_title: str | None = None - seo_description: str | None = None - quota_reserve_tokens: int | None = None - supports_image_input: bool = Field( - default=True, - description=( - "Whether the model accepts image inputs (multimodal vision). " - "Derived server-side: OpenRouter dynamic configs use " - "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " - "authoritative model map (``litellm.supports_vision``). The " - "new-chat selector hints with a 'No image' badge when this is " - "False and there are pending image attachments. The streaming " - "task fails fast only when LiteLLM *explicitly* marks a model " - "as text-only — unknown / unmapped models default-allow." - ), - ) - - -# ============================================================================= -# LLM Preferences Schemas (for role assignments) -# ============================================================================= - - -class LLMPreferencesRead(BaseModel): - """Schema for reading LLM preferences (role assignments) for a search space.""" - - agent_llm_id: int | None = Field( - None, description="ID of the LLM config to use for agent/chat tasks" - ) - image_generation_config_id: int | None = Field( - None, description="ID of the image generation config to use" - ) - vision_llm_config_id: int | None = Field( - None, - description="ID of the vision LLM config to use for vision/screenshot analysis", - ) - agent_llm: dict[str, Any] | None = Field( - None, description="Full config for chat model" - ) - image_generation_config: dict[str, Any] | None = Field( - None, description="Full config for image generation" - ) - vision_llm_config: dict[str, Any] | None = Field( - None, description="Full config for vision LLM" - ) - - model_config = ConfigDict(from_attributes=True) - - -class LLMPreferencesUpdate(BaseModel): - """Schema for updating LLM preferences.""" - - agent_llm_id: int | None = Field( - None, description="ID of the LLM config to use for agent/chat tasks" - ) - image_generation_config_id: int | None = Field( - None, description="ID of the image generation config to use" - ) - vision_llm_config_id: int | None = Field( - None, - description="ID of the vision LLM config to use for vision/screenshot analysis", - ) diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py deleted file mode 100644 index d0eeaf5c6..000000000 --- a/surfsense_backend/app/schemas/vision_llm.py +++ /dev/null @@ -1,116 +0,0 @@ -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import VisionProvider - - -class VisionLLMConfigBase(BaseModel): - name: str = Field(..., max_length=100) - description: str | None = Field(None, max_length=500) - provider: VisionProvider = Field(...) - custom_provider: str | None = Field(None, max_length=100) - model_name: str = Field(..., max_length=100) - api_key: str = Field(...) - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = Field(default=None) - - -class VisionLLMConfigCreate(VisionLLMConfigBase): - search_space_id: int = Field(...) - - -class VisionLLMConfigUpdate(BaseModel): - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - provider: VisionProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = None - - -class VisionLLMConfigRead(VisionLLMConfigBase): - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class VisionLLMConfigPublic(BaseModel): - id: int - name: str - description: str | None = None - provider: VisionProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class GlobalVisionLLMConfigRead(BaseModel): - """Schema for reading global vision LLM configs from YAML. - - The ``billing_tier`` field allows the frontend to show a Premium/Free - badge and (more importantly) tells the backend whether to debit the - user's premium credit pool when this config is used. ``"free"`` is - the default for backward compatibility — admins must explicitly opt - a global config into ``"premium"``. - """ - - id: int = Field(...) - name: str - description: str | None = None - provider: str - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - is_global: bool = True - is_auto_mode: bool = False - billing_tier: str = Field( - default="free", - description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", - ) - is_premium: bool = Field( - default=False, - description=( - "Convenience boolean derived server-side from " - "``billing_tier == 'premium'``. The new-chat model selector " - "keys its Free/Premium badge off this field for parity with " - "chat (`GlobalLLMConfigRead.is_premium`)." - ), - ) - quota_reserve_tokens: int | None = Field( - default=None, - description=( - "Optional override for the per-call reservation in *tokens* — " - "converted to micro-USD via the model's input/output prices at " - "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." - ), - ) - input_cost_per_token: float | None = Field( - default=None, - description=( - "Optional input price in USD/token. Used by pricing_registration to " - "register custom Azure / OpenRouter aliases with LiteLLM at startup." - ), - ) - output_cost_per_token: float | None = Field( - default=None, - description="Optional output price in USD/token. Pair with input_cost_per_token.", - ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index b4f1bafc9..dfd7c7be3 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -1,13 +1,13 @@ -"""Resolve and persist Auto (Fastest) model pins per chat thread. +"""Resolve and persist Auto model pins per chat thread. -Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we -resolve that virtual mode to one concrete global LLM config exactly once and +Auto is represented by ``chat_model_id == 0``. For chat threads we +resolve that virtual mode to one concrete global model exactly once and persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so subsequent turns are stable. Single-writer invariant: this module is the only writer of ``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in -``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). +``model_connections_routes`` when a search space's ``chat_model_id`` changes). Therefore a non-NULL value unambiguously means "this thread has an Auto-resolved pin"; no separate source/policy column is needed. """ @@ -33,8 +33,10 @@ from app.services.token_quota_service import TokenQuotaService logger = logging.getLogger(__name__) -AUTO_FASTEST_ID = 0 -AUTO_FASTEST_MODE = "auto_fastest" +AUTO_MODE_ID = 0 +# Stable internal hash namespace for deterministic per-thread selection. +# Do not rename: changing this rebalances Auto's model choice for new pins. +AUTO_PIN_HASH_NAMESPACE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 _HEALTHY_TTL_SECONDS = 45 @@ -383,7 +385,7 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: pool = tier_a if tier_a else eligible pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) top_k = pool[:_QUALITY_TOP_K] - digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() + digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest() idx = int.from_bytes(digest[:8], "big") % len(top_k) return top_k[idx], len(top_k) @@ -425,7 +427,7 @@ async def resolve_or_get_pinned_llm_config_id( exclude_config_ids: set[int] | None = None, requires_image_input: bool = False, ) -> AutoPinResolution: - """Resolve Auto (Fastest) to one concrete config id and persist the pin. + """Resolve Auto to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. @@ -457,7 +459,7 @@ async def resolve_or_get_pinned_llm_config_id( ) # Explicit model selected: clear any stale pin. - if selected_llm_config_id != AUTO_FASTEST_ID: + if selected_llm_config_id != AUTO_MODE_ID: if thread.pinned_llm_config_id is not None: thread.pinned_llm_config_id = None await session.commit() diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index f21f52e14..15a3c3e55 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space( Used by Celery tasks (podcast generation, video presentation) to bill the search-space owner's premium credit pool when the chat model is premium. - Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: + Resolution rules mirror the chat model role resolver: - - Search space not found / no ``agent_llm_id``: raise ``ValueError``. - - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): + - Search space not found / no ``chat_model_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_MODE_ID == 0``): * ``thread_id`` is set: delegate to ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and recurse into the resolved id. Reuses chat's existing pin if present @@ -469,9 +469,8 @@ async def _resolve_agent_billing_for_search_space( (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), ``base_model = litellm_params.get("base_model") or model_name`` — NOT provider-prefixed, matching chat's cost-map lookup convention. - - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches - ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); - ``base_model`` from ``litellm_params`` or ``model_name``. + - **Positive id** (user BYOK ``Model``): always free; ``base_model`` from + the model catalog override or the upstream ``model_id``. Note on imports: ``llm_service``, ``auto_model_pin_service``, and ``llm_router_service`` are imported lazily inside the function body to @@ -480,8 +479,9 @@ async def _resolve_agent_billing_for_search_space( ``billable_calls.py``'s module load path. """ from sqlalchemy import select + from sqlalchemy.orm import selectinload - from app.db import NewLLMConfig, SearchSpace + from app.db import Model, SearchSpace result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -490,20 +490,20 @@ async def _resolve_agent_billing_for_search_space( if search_space is None: raise ValueError(f"Search space {search_space_id} not found") - agent_llm_id = search_space.agent_llm_id - if agent_llm_id is None: + chat_model_id = search_space.chat_model_id + if chat_model_id is None: raise ValueError( - f"Search space {search_space_id} has no agent_llm_id configured" + f"Search space {search_space_id} has no chat_model_id configured" ) owner_user_id: UUID = search_space.user_id from app.services.auto_model_pin_service import ( - AUTO_FASTEST_ID, + AUTO_MODE_ID, resolve_or_get_pinned_llm_config_id, ) - if agent_llm_id == AUTO_FASTEST_ID: + if chat_model_id == AUTO_MODE_ID: if thread_id is None: return owner_user_id, "free", "auto" try: @@ -512,7 +512,7 @@ async def _resolve_agent_billing_for_search_space( thread_id=thread_id, search_space_id=search_space_id, user_id=str(owner_user_id), - selected_llm_config_id=AUTO_FASTEST_ID, + selected_llm_config_id=AUTO_MODE_ID, ) except ValueError: logger.warning( @@ -523,28 +523,35 @@ async def _resolve_agent_billing_for_search_space( exc_info=True, ) return owner_user_id, "free", "auto" - agent_llm_id = resolution.resolved_llm_config_id + chat_model_id = resolution.resolved_llm_config_id - if agent_llm_id < 0: + if chat_model_id < 0: from app.services.llm_service import get_global_llm_config - cfg = get_global_llm_config(agent_llm_id) or {} + cfg = get_global_llm_config(chat_model_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() litellm_params = cfg.get("litellm_params") or {} base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" return owner_user_id, billing_tier, base_model - nlc_result = await session.execute( - select(NewLLMConfig).where( - NewLLMConfig.id == agent_llm_id, - NewLLMConfig.search_space_id == search_space_id, - ) + model_result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == chat_model_id, Model.enabled.is_(True)) ) - nlc = nlc_result.scalars().first() + model = model_result.scalars().first() base_model = "" - if nlc is not None: - litellm_params = nlc.litellm_params or {} - base_model = litellm_params.get("base_model") or nlc.model_name or "" + if ( + model is not None + and model.connection is not None + and model.connection.enabled + and ( + model.connection.search_space_id in (None, search_space_id) + and model.connection.user_id in (None, owner_user_id) + ) + ): + catalog = model.catalog or {} + base_model = catalog.get("base_model") or model.model_id or "" return owner_user_id, "free", base_model diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index eadb4dbf8..277929e96 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -14,7 +14,11 @@ from app.services.auto_model_pin_service import ( auto_model_candidates, choose_auto_model_candidate, ) -from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode +from app.services.llm_router_service import ( + AUTO_MODE_ID, + ChatLiteLLMRouter, + is_auto_mode, +) from app.services.model_capabilities import has_capability from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -96,26 +100,16 @@ class LLMRole: def get_global_llm_config(llm_config_id: int) -> dict | None: """ Get a global LLM configuration by ID. - Global configs have negative IDs. ID 0 is reserved for Auto mode. + Global configs have negative IDs. Auto mode (ID 0) is resolved through the + model-candidate pipeline, not this legacy config lookup. Args: - llm_config_id: The ID of the global config (should be negative or 0 for Auto) + llm_config_id: The ID of the global config (must be negative) Returns: dict: Global config dictionary or None if not found """ - # Auto mode (ID 0) is handled separately via the router - if llm_config_id == AUTO_MODE_ID: - return { - "id": AUTO_MODE_ID, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - - if llm_config_id > 0: + if llm_config_id >= 0: return None for cfg in config.GLOBAL_LLM_CONFIGS: diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 0761d7e4f..ffb430756 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -24,7 +24,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours _cache: list[dict] | None = None _cache_timestamp: float = 0 -# Maps OpenRouter provider slug → our LiteLLMProvider enum value. +# Maps OpenRouter provider slug to native LiteLLM provider prefixes. # Only providers where the model-name part (after the slash) can be # used directly with the native provider's litellm prefix are listed. # diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index fbb70eb5a..8f4c4cb5f 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -281,7 +281,7 @@ def _generate_configs( OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer - because our own Auto (Fastest) pin + 24 h refresh + repair logic already + because our own Auto pin + 24 h refresh + repair logic already cover the catalogue-churn case. """ id_offset: int = settings.get("id_offset", -10000) @@ -346,7 +346,7 @@ def _generate_configs( # ``"No endpoints found that support image input"``. "supports_image_input": bool(normalized.get("supports_image_input")), _OPENROUTER_DYNAMIC_MARKER: True, - # Auto (Fastest) ranking metadata. ``quality_score`` is initialised + # Auto ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next # ``_enrich_health`` pass (synchronous on refresh, deferred on cold # start so startup latency is unchanged). @@ -361,11 +361,7 @@ def _generate_configs( return configs -# ID-offset bands used to keep dynamic OpenRouter configs in their own -# namespace per surface. Image / vision get separate bands so a single -# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 -_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 def _generate_image_gen_configs( @@ -431,89 +427,6 @@ def _generate_image_gen_configs( return configs -def _generate_vision_llm_configs( - raw_models: list[dict], settings: dict[str, Any] -) -> list[dict]: - """Convert OpenRouter vision-capable LLMs into global vision-LLM config - dicts (matches the YAML shape consumed by ``vision_llm_routes``). - - Filter: - - architecture.input_modalities contains "image" - - architecture.output_modalities contains "text" - - compatible provider (excluded slugs blocked) - - allowed model id (excluded list blocked) - - Vision-LLM is invoked from the indexer (image extraction during - document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so - the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` - filters do not apply: a small-context vision model that doesn't - advertise tool-calling is still perfectly viable for "describe this - image" prompts. - """ - id_offset: int = int( - settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT - ) - api_key: str = settings.get("api_key", "") - rpm: int = settings.get("rpm", 200) - tpm: int = settings.get("tpm", 1_000_000) - free_rpm: int = settings.get("free_rpm", 20) - free_tpm: int = settings.get("free_tpm", 100_000) - quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) - litellm_params: dict = settings.get("litellm_params") or {} - - vision_models = [ - m - for m in raw_models - if supports_image_input(m) - and _shared_is_compatible_provider(m) - and _shared_is_allowed_model(m) - and "/" in m.get("id", "") - ] - - configs: list[dict] = [] - taken: set[int] = set() - for model in vision_models: - model_id: str = model["id"] - name: str = model.get("name", model_id) - tier = _openrouter_tier(model) - pricing = model.get("pricing") or {} - - # Capture per-token prices so ``pricing_registration`` can - # register them with LiteLLM at startup (and so the cost - # estimator in ``estimate_call_reserve_micros`` can resolve - # them at reserve time). - try: - input_cost = float(pricing.get("prompt", 0) or 0) - except (TypeError, ValueError): - input_cost = 0.0 - try: - output_cost = float(pricing.get("completion", 0) or 0) - except (TypeError, ValueError): - output_cost = 0.0 - - cfg: dict[str, Any] = { - "id": _stable_config_id(model_id, id_offset, taken), - "name": name, - "description": f"{name} via OpenRouter (vision)", - "provider": "openrouter", - "model_name": model_id, - "api_key": api_key, - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "rpm": free_rpm if tier == "free" else rpm, - "tpm": free_tpm if tier == "free" else tpm, - "litellm_params": dict(litellm_params), - "billing_tier": tier, - "quota_reserve_tokens": quota_reserve_tokens, - "input_cost_per_token": input_cost or None, - "output_cost_per_token": output_cost or None, - _OPENROUTER_DYNAMIC_MARKER: True, - } - configs.append(cfg) - - return configs - - class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -724,7 +637,7 @@ class OpenRouterIntegrationService: return counts # ------------------------------------------------------------------ - # Auto (Fastest) health enrichment + # Auto health enrichment # ------------------------------------------------------------------ async def _enrich_health_safely( diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 9e4e3b552..7343df737 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -154,10 +154,8 @@ def _register_chat_shape_configs( input_cost = _safe_float(entry.get("prompt")) output_cost = _safe_float(entry.get("completion")) else: - # Vision configs from ``_generate_vision_llm_configs`` - # carry their pricing inline because the OpenRouter - # raw-pricing cache is keyed by chat-catalogue model_id; - # vision flows pick up the inline values here. + # Some dynamically materialized configs can carry pricing + # inline when the raw OpenRouter cache has no matching entry. input_cost = _safe_float(cfg.get("input_cost_per_token")) output_cost = _safe_float(cfg.get("output_cost_per_token")) if input_cost == 0.0 and output_cost == 0.0: diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 9cc9c21ac..737dd7c2f 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -1,4 +1,4 @@ -"""Pure-function quality scoring for Auto (Fastest) model selection. +"""Pure-function quality scoring for Auto model selection. This module is import-free of any service / request-path dependencies. All numbers are computed once during the OpenRouter refresh tick (or YAML load) diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py deleted file mode 100644 index 0ff716324..000000000 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ /dev/null @@ -1,160 +0,0 @@ -import logging -from typing import Any - -from litellm import Router - -from app.services.model_resolver import native_connection_from_config, to_litellm - -logger = logging.getLogger(__name__) - -VISION_AUTO_MODE_ID = 0 - -class VisionLLMRouterService: - _instance = None - _router: Router | None = None - _model_list: list[dict] = [] - _router_settings: dict = {} - _initialized: bool = False - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @classmethod - def get_instance(cls) -> "VisionLLMRouterService": - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def initialize( - cls, - global_configs: list[dict], - router_settings: dict | None = None, - ) -> None: - instance = cls.get_instance() - - if instance._initialized: - logger.debug("Vision LLM Router already initialized, skipping") - return - - model_list = [] - for config in global_configs: - deployment = cls._config_to_deployment(config) - if deployment: - model_list.append(deployment) - - if not model_list: - logger.warning( - "No valid vision LLM configs found for router initialization" - ) - return - - instance._model_list = model_list - instance._router_settings = router_settings or {} - - default_settings = { - "routing_strategy": "usage-based-routing", - "num_retries": 3, - "allowed_fails": 3, - "cooldown_time": 60, - "retry_after": 5, - } - - final_settings = {**default_settings, **instance._router_settings} - - try: - instance._router = Router( - model_list=model_list, - routing_strategy=final_settings.get( - "routing_strategy", "usage-based-routing" - ), - num_retries=final_settings.get("num_retries", 3), - allowed_fails=final_settings.get("allowed_fails", 3), - cooldown_time=final_settings.get("cooldown_time", 60), - set_verbose=False, - ) - instance._initialized = True - logger.info( - "Vision LLM Router initialized with %d deployments, strategy: %s", - len(model_list), - final_settings.get("routing_strategy"), - ) - except Exception as e: - logger.error(f"Failed to initialize Vision LLM Router: {e}") - instance._router = None - - @classmethod - def _config_to_deployment(cls, config: dict) -> dict | None: - try: - if not config.get("model_name") or not config.get("api_key"): - return None - - model_string, resolved_kwargs = to_litellm( - native_connection_from_config(config), - config["model_name"], - ) - litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} - - deployment: dict[str, Any] = { - "model_name": "auto", - "litellm_params": litellm_params, - } - - if config.get("rpm"): - deployment["rpm"] = config["rpm"] - if config.get("tpm"): - deployment["tpm"] = config["tpm"] - - return deployment - - except Exception as e: - logger.warning(f"Failed to convert vision config to deployment: {e}") - return None - - @classmethod - def get_router(cls) -> Router | None: - instance = cls.get_instance() - return instance._router - - @classmethod - def is_initialized(cls) -> bool: - instance = cls.get_instance() - return instance._initialized and instance._router is not None - - @classmethod - def get_model_count(cls) -> int: - instance = cls.get_instance() - return len(instance._model_list) - - -def is_vision_auto_mode(config_id: int | None) -> bool: - return config_id == VISION_AUTO_MODE_ID - - -def build_vision_model_string( - litellm_provider: str, model_name: str, custom_provider: str | None -) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - return f"{litellm_provider}/{model_name}" - - -def get_global_vision_llm_config(config_id: int) -> dict | None: - from app.config import config - - if config_id == VISION_AUTO_MODE_ID: - return { - "id": VISION_AUTO_MODE_ID, - "name": "Auto (Fastest)", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - if config_id > 0: - return None - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py deleted file mode 100644 index 6eae8c455..000000000 --- a/surfsense_backend/app/services/vision_model_list_service.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Service for fetching and caching the vision-capable model list. - -Reuses the same OpenRouter public API and local fallback as the LLM model -list service, but filters for models that accept image input. -""" - -import json -import logging -import time -from pathlib import Path - -import httpx - -logger = logging.getLogger(__name__) - -OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" -FALLBACK_FILE = ( - Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json" -) -CACHE_TTL_SECONDS = 86400 # 24 hours - -_cache: list[dict] | None = None -_cache_timestamp: float = 0 - -OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = { - "openai": "OPENAI", - "anthropic": "ANTHROPIC", - "google": "GOOGLE", - "mistralai": "MISTRAL", - "x-ai": "XAI", -} - - -def _format_context_length(length: int | None) -> str | None: - if not length: - return None - if length >= 1_000_000: - return f"{length / 1_000_000:g}M" - if length >= 1_000: - return f"{length / 1_000:g}K" - return str(length) - - -async def _fetch_from_openrouter() -> list[dict] | None: - try: - async with httpx.AsyncClient(timeout=15) as client: - response = await client.get(OPENROUTER_API_URL) - response.raise_for_status() - data = response.json() - return data.get("data", []) - except Exception as e: - logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e) - return None - - -def _load_fallback() -> list[dict]: - try: - with open(FALLBACK_FILE, encoding="utf-8") as f: - return json.load(f) - except Exception as e: - logger.error("Failed to load vision model fallback list: %s", e) - return [] - - -def _is_vision_model(model: dict) -> bool: - """Return True if the model accepts image input and outputs text.""" - arch = model.get("architecture", {}) - input_mods = arch.get("input_modalities", []) - output_mods = arch.get("output_modalities", []) - return "image" in input_mods and "text" in output_mods - - -def _process_vision_models(raw_models: list[dict]) -> list[dict]: - processed: list[dict] = [] - - for model in raw_models: - model_id: str = model.get("id", "") - name: str = model.get("name", "") - context_length = model.get("context_length") - - if "/" not in model_id: - continue - - if not _is_vision_model(model): - continue - - provider_slug, model_name = model_id.split("/", 1) - context_window = _format_context_length(context_length) - - processed.append( - { - "value": model_id, - "label": name, - "provider": "OPENROUTER", - "context_window": context_window, - } - ) - - direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) - if direct_provider: - if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): - continue - - processed.append( - { - "value": model_name, - "label": name, - "provider": direct_provider, - "context_window": context_window, - } - ) - - return processed - - -async def get_vision_model_list() -> list[dict]: - global _cache, _cache_timestamp - - if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS: - return _cache - - raw_models = await _fetch_from_openrouter() - - if raw_models is None: - logger.info("Using fallback vision model list") - return _load_fallback() - - processed = _process_vision_models(raw_models) - - _cache = processed - _cache_timestamp = time.time() - - return processed diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py index 6e711f99a..e6a535711 100644 --- a/surfsense_backend/scripts/verify_chat_image_capability.py +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -330,31 +330,6 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None: report.add(result) -async def probe_vision_configs(report: Report, *, live: bool) -> None: - print("\n[vision configs from global_vision_llm_configs (YAML-static)]") - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if _is_or_dynamic(cfg): - continue - result = ProbeResult( - label=str(cfg.get("name") or cfg.get("model_name")), - surface="vision", - config_id=cfg.get("id"), - ) - # For vision configs, capability is implied — they're in the - # dedicated vision pool. Run the same resolver to flag any - # surprise disagreement. - cap_ok, cap_note = _probe_chat_capability(cfg) - result.capability_ok = cap_ok - result.capability_note = cap_note - if live: - t0 = time.perf_counter() - ok, note = await _live_chat_image_call(cfg) - result.live_ok = ok - result.live_note = note - result.duration_s = time.perf_counter() - t0 - report.add(result) - - async def probe_image_gen_configs(report: Report, *, live: bool) -> None: print( "\n[image generation configs from global_image_generation_configs (YAML-static)]" @@ -380,7 +355,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None: async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: - """Sample one chat (vision-capable), one vision, one image-gen model + """Sample chat/vision-capable and image-gen models from the live OpenRouter catalogue. Doesn't iterate the full pool (would be hundreds of probes); just validates the integration end- to-end on a representative model from each surface.""" @@ -405,9 +380,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: for c in config.GLOBAL_LLM_CONFIGS if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") ] - or_vision = [ - c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" - ] or_image_gen = [ c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" ] @@ -427,11 +399,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ("or-chat", _pick_first(or_chat, "anthropic/claude")), ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), ] - vision_picks = [ - ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), - ("or-vision", _pick_first(or_vision, "anthropic/claude")), - ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), - ] image_picks = [ ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` @@ -441,11 +408,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ] print( - f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " + f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} " f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" ) - for surface, picked in chat_picks + vision_picks + image_picks: + for surface, picked in chat_picks + image_picks: if not picked: report.add( ProbeResult( @@ -486,7 +453,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: async def main(args: argparse.Namespace) -> int: print("Loaded global configs:") print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") - print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") @@ -507,8 +473,6 @@ async def main(args: argparse.Namespace) -> int: report = Report() if not args.skip_chat: await probe_chat_configs(report, live=args.live) - if not args.skip_vision: - await probe_vision_configs(report, live=args.live) if not args.skip_image_gen: await probe_image_gen_configs(report, live=args.live) if not args.skip_openrouter: @@ -528,7 +492,6 @@ def _parse_args() -> argparse.Namespace: ) parser.set_defaults(live=True) parser.add_argument("--skip-chat", action="store_true") - parser.add_argument("--skip-vision", action="store_true") parser.add_argument("--skip-image-gen", action="store_true") parser.add_argument("--skip-openrouter", action="store_true") return parser.parse_args() diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py index 79da12933..f5709e517 100644 --- a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py @@ -1,6 +1,6 @@ """Lock the runtime model-policy backstop in ``build_dependencies``. -Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so +Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so runs are insulated from later chat/search-space model changes), and the model policy is re-checked at run time so a captured model that is no longer billable fails the run clearly. When no snapshot is present, resolution falls back to the @@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch): return None -async def test_build_dependencies_resolves_captured_agent_llm_id( +async def test_build_dependencies_resolves_captured_chat_model_id( monkeypatch: pytest.MonkeyPatch, patched_side_effects ) -> None: - """The bundle loads with the *captured* ``agent_llm_id``, not the live search space.""" + """The bundle loads with the *captured* ``chat_model_id``, not the live search space.""" captured: dict[str, Any] = {} async def _fake_load(_session, *, config_id, search_space_id): @@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_agent_llm_id( lambda _ss: pytest.fail("search-space policy should not run on captured path"), ) - search_space = SimpleNamespace(agent_llm_id=-99) + search_space = SimpleNamespace(chat_model_id=-99) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) assert captured == {"config_id": -7, "search_space_id": 42} @@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids( monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) await build_dependencies( - session=_FakeSession(SimpleNamespace(agent_llm_id=0)), + session=_FakeSession(SimpleNamespace(chat_model_id=0)), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) assert seen == { - "agent_llm_id": -7, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -7, + "image_gen_model_id": 5, + "vision_model_id": -1, } @@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation( def _raise(**_kw): raise AutomationModelPolicyError( - [{"kind": "image", "config_id": -2, "reason": "free model"}] + [{"kind": "image", "model_id": -2, "reason": "free model"}] ) monkeypatch.setattr(deps_mod, "assert_models_billable", _raise) @@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation( with pytest.raises(DependencyError): await build_dependencies( - session=_FakeSession(SimpleNamespace(agent_llm_id=-7)), + session=_FakeSession(SimpleNamespace(chat_model_id=-7)), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=-2, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=-2, + vision_model_id=-1, ) @@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space( lambda **_kw: pytest.fail("captured policy should not run on fallback path"), ) - search_space = SimpleNamespace(agent_llm_id=-7) + search_space = SimpleNamespace(chat_model_id=-7) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42 ) diff --git a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py index d7e3c4a0c..c89624fbf 100644 --- a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py +++ b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py @@ -28,9 +28,9 @@ def _run() -> SimpleNamespace: def test_build_action_ctx_propagates_captured_models() -> None: """``definition.models`` flows onto the ActionContext model fields.""" models = AutomationModels( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ) ctx = _build_action_ctx( cast(AsyncSession, None), @@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None: ) assert ctx.search_space_id == 42 - assert ctx.agent_llm_id == -1 - assert ctx.image_generation_config_id == 5 - assert ctx.vision_llm_config_id == -1 + assert ctx.chat_model_id == -1 + assert ctx.image_gen_model_id == 5 + assert ctx.vision_model_id == -1 def test_build_action_ctx_none_models_leaves_fields_none() -> None: @@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None: None, ) - assert ctx.agent_llm_id is None - assert ctx.image_generation_config_id is None - assert ctx.vision_llm_config_id is None + assert ctx.chat_model_id is None + assert ctx.image_gen_model_id is None + assert ctx.vision_model_id is None diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py index 25e193ffa..dc7221b11 100644 --- a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py @@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None: name="Daily digest", plan=[PlanStep(step_id="s1", action="agent_task")], models=AutomationModels( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ), ) dumped = definition.model_dump(mode="json", by_alias=True) assert dumped["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } restored = AutomationDefinition.model_validate(dumped) assert restored.models is not None - assert restored.models.agent_llm_id == -1 - assert restored.models.image_generation_config_id == 5 - assert restored.models.vision_llm_config_id == -1 + assert restored.models.chat_model_id == -1 + assert restored.models.image_gen_model_id == 5 + assert restored.models.vision_model_id == -1 def test_automation_definition_rejects_unknown_top_level_field() -> None: diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py index 0bbff39dc..c97dec6a2 100644 --- a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation( def _raise(_ss): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": 0, "reason": "Auto mode"}] + [{"kind": "llm", "model_id": 0, "reason": "Auto mode"}] ) monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise) - service = _service(SimpleNamespace(agent_llm_id=0)) + service = _service(SimpleNamespace(chat_model_id=0)) with pytest.raises(HTTPException) as exc_info: await service._assert_models_billable(1) @@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok( automation_mod, "assert_automation_models_billable", lambda _ss: None ) - search_space = SimpleNamespace(agent_llm_id=-1) + search_space = SimpleNamespace(chat_model_id=-1) service = _service(search_space) assert await service._assert_models_billable(1) is search_space @@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ) service = _service(search_space) payload = AutomationCreate( @@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space( automation = await service.create(payload) assert automation.definition["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } @@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - agent_llm_id=None, - image_generation_config_id=None, - vision_llm_config_id=None, + chat_model_id=None, + image_gen_model_id=None, + vision_model_id=None, ) service = _service(search_space) payload = AutomationCreate(search_space_id=1, name="A", definition=_definition()) @@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( automation = await service.create(payload) assert automation.definition["models"] == { - "agent_llm_id": 0, - "image_generation_config_id": 0, - "vision_llm_config_id": 0, + "chat_model_id": 0, + "image_gen_model_id": 0, + "vision_model_id": 0, } @@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided( ) validated: dict[str, Any] = {} - def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): validated["ids"] = ( - agent_llm_id, - image_generation_config_id, - vision_llm_config_id, + chat_model_id, + image_gen_model_id, + vision_model_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) - service = _service(SimpleNamespace(agent_llm_id=-99)) + service = _service(SimpleNamespace(chat_model_id=-99)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - agent_llm_id=-1, - image_generation_config_id=7, - vision_llm_config_id=-2, + chat_model_id=-1, + image_gen_model_id=7, + vision_model_id=-2, ) ), ) @@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided( assert validated["ids"] == (-1, 7, -2) assert automation.definition["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 7, - "vision_llm_config_id": -2, + "chat_model_id": -1, + "image_gen_model_id": 7, + "vision_model_id": -2, } @@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models( ) -> None: """A non-billable explicit selection maps the policy error to HTTP 422.""" - def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": -3, "reason": "free model"}] + [{"kind": "llm", "model_id": -3, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) - service = _service(SimpleNamespace(agent_llm_id=-3)) + service = _service(SimpleNamespace(chat_model_id=-3)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - agent_llm_id=-3, - image_generation_config_id=7, - vision_llm_config_id=-2, + chat_model_id=-3, + image_gen_model_id=7, + vision_model_id=-2, ) ), ) @@ -277,9 +277,9 @@ async def test_update_preserves_captured_models( ) -> None: """A definition edit carries over the previously captured ``models``.""" captured = { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid( "name": "A", "plan": [], "models": { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, }, }, version=3, ) validated: dict[str, Any] = {} - def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): validated["ids"] = ( - agent_llm_id, - image_generation_config_id, - vision_llm_config_id, + chat_model_id, + image_gen_model_id, + vision_model_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - agent_llm_id=-2, - image_generation_config_id=9, - vision_llm_config_id=-2, + chat_model_id=-2, + image_gen_model_id=9, + vision_model_id=-2, ) ) ) @@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid( assert validated["ids"] == (-2, 9, -2) assert result.definition["models"] == { - "agent_llm_id": -2, - "image_generation_config_id": 9, - "vision_llm_config_id": -2, + "chat_model_id": -2, + "image_gen_model_id": 9, + "vision_model_id": -2, } assert result.version == 4 @@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models( "name": "A", "plan": [], "models": { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, }, }, version=3, ) - def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": -7, "reason": "free model"}] + [{"kind": "llm", "model_id": -7, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) ) ) @@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation( premium without an unrelated edit tripping the policy check. """ captured = { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload( lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]}, ) - service = _service(SimpleNamespace(agent_llm_id=-2)) + service = _service(SimpleNamespace(chat_model_id=-2)) result = await service.model_eligibility(search_space_id=5) assert result == {"allowed": False, "violations": [{"kind": "image"}]} diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py b/surfsense_backend/tests/unit/automations/services/test_model_policy.py index 8e0806151..574f6d9fd 100644 --- a/surfsense_backend/tests/unit/automations/services/test_model_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit def _search_space(*, llm: int | None, image: int | None, vision: int | None): """Minimal stand-in for the ``SearchSpace`` ORM row the policy reads.""" return SimpleNamespace( - agent_llm_id=llm, - image_generation_config_id=image, - vision_llm_config_id=vision, + chat_model_id=llm, + image_gen_model_id=image, + vision_model_id=vision, ) @@ -39,29 +39,11 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): Negative ids: -1 is premium, -2 is free, for each of llm/image/vision. """ - llm_configs = { - -1: {"id": -1, "billing_tier": "premium"}, - -2: {"id": -2, "billing_tier": "free"}, - } - monkeypatch.setattr( - "app.agents.chat.runtime.llm_config.load_global_llm_config_by_id", - lambda cid: llm_configs.get(cid), - ) - from app.config import config as app_config monkeypatch.setattr( app_config, - "GLOBAL_IMAGE_GEN_CONFIGS", - [ - {"id": -1, "billing_tier": "premium"}, - {"id": -2, "billing_tier": "free"}, - ], - raising=False, - ) - monkeypatch.setattr( - app_config, - "GLOBAL_VISION_LLM_CONFIGS", + "GLOBAL_MODELS", [ {"id": -1, "billing_tier": "premium"}, {"id": -2, "billing_tier": "free"}, @@ -71,7 +53,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): return None -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: """A positive config id is a user-owned BYOK model — always billable.""" allowed, reason = model_policy._classify(kind, 7) @@ -79,7 +61,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) @pytest.mark.parametrize("config_id", [0, None]) def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: """Auto mode (id 0) and an unset slot (None) are blocked.""" @@ -88,7 +70,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: assert "Auto mode" in reason -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_premium_global_is_allowed(kind: str, patched_globals) -> None: """A negative (global) id with premium billing tier is allowed.""" allowed, reason = model_policy._classify(kind, -1) @@ -96,7 +78,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_free_global_is_blocked(kind: str, patched_globals) -> None: """A negative (global) id with a free billing tier is blocked.""" allowed, reason = model_policy._classify(kind, -2) @@ -104,7 +86,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None: assert "free model" in reason -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None: """A negative id that resolves to no config is treated as not premium.""" allowed, _ = model_policy._classify(kind, -999) @@ -125,10 +107,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None: assert result["allowed"] is False kinds = {v["kind"] for v in result["violations"]} - assert kinds == {"llm", "image", "vision"} - # config_id is echoed back for the UI / settings deep-link. - by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} - assert by_kind == {"llm": -2, "image": 0, "vision": -2} + assert kinds == {"chat", "image", "vision"} + # model_id is echoed back for the UI / settings deep-link. + by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} + assert by_kind == {"chat": -2, "image": 0, "vision": -2} def test_assert_raises_with_violations(patched_globals) -> None: @@ -138,7 +120,7 @@ def test_assert_raises_with_violations(patched_globals) -> None: assert_automation_models_billable(search_space) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "llm" + assert exc_info.value.violations[0]["kind"] == "chat" def test_assert_passes_when_all_billable(patched_globals) -> None: @@ -153,7 +135,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None: def test_get_model_eligibility_all_billable(patched_globals) -> None: """Premium LLM + BYOK image + premium vision (explicit ids) → allowed.""" result = get_model_eligibility( - agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1 + chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1 ) assert result == {"allowed": True, "violations": []} @@ -161,28 +143,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None: def test_get_model_eligibility_reports_each_violation(patched_globals) -> None: """Free LLM, Auto image, free vision (explicit ids) each produce a violation.""" result = get_model_eligibility( - agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 ) assert result["allowed"] is False - by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} - assert by_kind == {"llm": -2, "image": 0, "vision": -2} + by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} + assert by_kind == {"chat": -2, "image": 0, "vision": -2} def test_assert_models_billable_raises(patched_globals) -> None: """``assert_models_billable`` raises when any explicit id is blocked.""" with pytest.raises(AutomationModelPolicyError) as exc_info: assert_models_billable( - agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1 + chat_model_id=0, image_gen_model_id=5, vision_model_id=-1 ) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "llm" + assert exc_info.value.violations[0]["kind"] == "chat" def test_assert_models_billable_passes(patched_globals) -> None: """No exception when every explicit id is premium or BYOK.""" assert ( assert_models_billable( - agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4 + chat_model_id=3, image_gen_model_id=-1, vision_model_id=4 ) is None ) @@ -192,5 +174,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None: """The search-space wrapper produces the same result as the ID core.""" search_space = _search_space(llm=-2, image=0, vision=-2) assert get_automation_model_eligibility(search_space) == get_model_eligibility( - agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 ) diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py deleted file mode 100644 index c9f18d77d..000000000 --- a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Unit tests for ``supports_image_input`` derivation on BYOK chat config -endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). - -There is no DB column for ``supports_image_input`` on -``NewLLMConfig`` — the value is resolved at the API boundary by -``derive_supports_image_input`` so the new-chat selector / streaming -task can read the same field shape regardless of source (BYOK vs YAML -vs OpenRouter dynamic). Default-allow on unknown so we don't lock the -user out of their own model choice. -""" - -from __future__ import annotations - -from datetime import UTC, datetime -from types import SimpleNamespace -from uuid import uuid4 - -import pytest - -from app.db import LiteLLMProvider -from app.routes import new_llm_config_routes - -pytestmark = pytest.mark.unit - - -def _byok_row( - *, - id_: int, - model_name: str, - base_model: str | None = None, - provider: LiteLLMProvider = LiteLLMProvider.OPENAI, - custom_provider: str | None = None, -) -> object: - """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` - walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. - - ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's - enum validator accepts it — same as the ORM row would carry.""" - return SimpleNamespace( - id=id_, - name=f"BYOK-{id_}", - description=None, - provider=provider, - custom_provider=custom_provider, - model_name=model_name, - api_key="sk-byok", - api_base=None, - litellm_params={"base_model": base_model} if base_model else None, - system_instructions="", - use_default_system_instructions=True, - citations_enabled=True, - created_at=datetime.now(tz=UTC), - search_space_id=42, - user_id=uuid4(), - ) - - -def test_serialize_byok_known_vision_model_resolves_true(): - """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> - True. The serialized row carries that value through to the - ``NewLLMConfigRead`` schema.""" - row = _byok_row(id_=1, model_name="gpt-4o") - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - assert serialized.id == 1 - assert serialized.model_name == "gpt-4o" - - -def test_serialize_byok_unknown_model_default_allows(): - """Unknown / unmapped: default-allow. The streaming-task safety net - is the actual block, and it requires LiteLLM to *explicitly* say - text-only — so a brand new BYOK model should not be pre-judged.""" - row = _byok_row( - id_=2, - model_name="brand-new-model-x9-unmapped", - provider=LiteLLMProvider.CUSTOM, - custom_provider="brand_new_proxy", - ) - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - - -def test_serialize_byok_uses_base_model_when_present(): - """Azure-style: ``model_name`` is the deployment id, ``base_model`` - inside ``litellm_params`` is the canonical sku LiteLLM knows. The - helper must consult ``base_model`` first or unrecognised deployment - ids would shadow the real capability.""" - row = _byok_row( - id_=3, - model_name="my-azure-deployment-id-no-litellm-knows-this", - base_model="gpt-4o", - provider=LiteLLMProvider.AZURE_OPENAI, - ) - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - - -def test_serialize_byok_returns_pydantic_read_model(): - """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so - the schema additions are guaranteed to be present in the API - surface. This guards against a future regression where someone - deletes the augmentation step and falls back to ORM passthrough.""" - from app.schemas import NewLLMConfigRead - - row = _byok_row(id_=4, model_name="gpt-4o") - serialized = new_llm_config_routes._serialize_byok_config(row) - assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py deleted file mode 100644 index fff61f14b..000000000 --- a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Unit tests for ``is_premium`` derivation on the global image-gen and -vision-LLM list endpoints. - -Chat globals (``GET /global-llm-configs``) already emit -``is_premium = (billing_tier == "premium")``. Image and vision did not, -which made the new-chat ``model-selector`` render the Free/Premium badge -on the Chat tab but skip it on the Image and Vision tabs (the selector -keys its badge logic off ``is_premium``). These tests pin parity: - -* YAML free entry → ``is_premium=False`` -* YAML premium entry → ``is_premium=True`` -* OpenRouter dynamic premium entry → ``is_premium=True`` -* Auto stub (always emitted when at least one config is present) - → ``is_premium=False`` -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.unit - - -_IMAGE_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "DALL-E 3", - "litellm_provider": "openai", - "model_name": "dall-e-3", - "api_key": "sk-test", - "billing_tier": "free", - }, - { - "id": -2, - "name": "GPT-Image 1 (premium)", - "litellm_provider": "openai", - "model_name": "gpt-image-1", - "api_key": "sk-test", - "billing_tier": "premium", - }, - { - "id": -20_001, - "name": "google/gemini-2.5-flash-image (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash-image", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, -] - - -_VISION_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "GPT-4o Vision", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - }, - { - "id": -2, - "name": "Claude 3.5 Sonnet (premium)", - "litellm_provider": "anthropic", - "model_name": "claude-3-5-sonnet", - "api_key": "sk-ant-test", - "billing_tier": "premium", - }, - { - "id": -30_001, - "name": "openai/gpt-4o (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, -] - - -# ============================================================================= -# Image generation -# ============================================================================= - - -@pytest.mark.asyncio -async def test_global_image_gen_configs_emit_is_premium(monkeypatch): - """Each emitted config must carry ``is_premium`` derived server-side - from ``billing_tier``. The Auto stub is always free. - """ - from app.config import config - from app.routes import image_generation_routes - - monkeypatch.setattr( - config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False - ) - - payload = await image_generation_routes.get_global_image_gen_configs(user=None) - - by_id = {c["id"]: c for c in payload} - - # Auto stub is always emitted when at least one global config exists, - # and it must always declare itself free (Auto-mode billing-tier - # surfacing is a separate follow-up). - assert 0 in by_id, "Auto stub should be emitted when at least one config exists" - assert by_id[0]["is_premium"] is False - assert by_id[0]["billing_tier"] == "free" - - # YAML free entry — ``is_premium=False`` - assert by_id[-1]["is_premium"] is False - assert by_id[-1]["billing_tier"] == "free" - - # YAML premium entry — ``is_premium=True`` - assert by_id[-2]["is_premium"] is True - assert by_id[-2]["billing_tier"] == "premium" - - # OpenRouter dynamic premium entry — same field, same derivation - assert by_id[-20_001]["is_premium"] is True - assert by_id[-20_001]["billing_tier"] == "premium" - - # Every emitted dict (including Auto) must have the field — never missing. - for cfg in payload: - assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" - assert isinstance(cfg["is_premium"], bool) - - -@pytest.mark.asyncio -async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): - """When there are no global configs at all, the endpoint emits an - empty list (no Auto stub) — Auto mode would have nothing to route to. - """ - from app.config import config - from app.routes import image_generation_routes - - monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) - payload = await image_generation_routes.get_global_image_gen_configs(user=None) - assert payload == [] - - -# ============================================================================= -# Vision LLM -# ============================================================================= - - -@pytest.mark.asyncio -async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): - from app.config import config - from app.routes import vision_llm_routes - - monkeypatch.setattr( - config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False - ) - - payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) - - by_id = {c["id"]: c for c in payload} - - assert 0 in by_id, "Auto stub should be emitted when at least one config exists" - assert by_id[0]["is_premium"] is False - assert by_id[0]["billing_tier"] == "free" - - assert by_id[-1]["is_premium"] is False - assert by_id[-1]["billing_tier"] == "free" - - assert by_id[-2]["is_premium"] is True - assert by_id[-2]["billing_tier"] == "premium" - - assert by_id[-30_001]["is_premium"] is True - assert by_id[-30_001]["billing_tier"] == "premium" - - for cfg in payload: - assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" - assert isinstance(cfg["is_premium"], bool) - - -@pytest.mark.asyncio -async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): - from app.config import config - from app.routes import vision_llm_routes - - monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) - payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) - assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py deleted file mode 100644 index 67d1112f3..000000000 --- a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Unit tests for ``supports_image_input`` derivation on the chat global -config endpoint (``GET /global-new-llm-configs``). - -Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): - -1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML - loader for operator overrides, or by the OpenRouter integration from - ``architecture.input_modalities``) — wins. -2. ``derive_supports_image_input`` helper — default-allow on unknown - models, only False when LiteLLM / OR modalities are definitive. - -The flag is purely informational at the API boundary. The streaming -task safety net (``is_known_text_only_chat_model``) is the actual block, -and it requires LiteLLM to *explicitly* mark the model as text-only. -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.unit - - -_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "GPT-4o (explicit true)", - "description": "vision-capable, explicit YAML override", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - "supports_image_input": True, - }, - { - "id": -2, - "name": "DeepSeek V3 (explicit false)", - "description": "OpenRouter dynamic — modality-derived false", - "litellm_provider": "openrouter", - "model_name": "deepseek/deepseek-v3.2-exp", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "free", - "supports_image_input": False, - }, - { - "id": -10_010, - "name": "Unannotated GPT-4o", - "description": "no flag set — resolver should derive True via LiteLLM", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - # supports_image_input intentionally absent - }, - { - "id": -10_011, - "name": "Unannotated unknown model", - "description": "unmapped — default-allow True", - "litellm_provider": "custom", - "custom_provider": "brand_new_proxy", - "model_name": "brand-new-model-x9", - "api_key": "sk-test", - "billing_tier": "free", - }, -] - - -@pytest.mark.asyncio -async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): - """Each emitted chat config carries ``supports_image_input`` as a - bool. Explicit values win; unannotated entries are resolved via the - helper (default-allow True).""" - from app.config import config - from app.routes import new_llm_config_routes - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) - - payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) - by_id = {c["id"]: c for c in payload} - - # Auto stub: optimistic True so the user can keep Auto selected with - # vision-capable deployments somewhere in the pool. - assert 0 in by_id, "Auto stub should be emitted when configs exist" - assert by_id[0]["supports_image_input"] is True - assert by_id[0]["is_auto_mode"] is True - - # Explicit True is preserved. - assert by_id[-1]["supports_image_input"] is True - - # Explicit False is preserved (the exact failure mode the safety net - # guards against — DeepSeek V3 over OpenRouter would 404 with "No - # endpoints found that support image input"). - assert by_id[-2]["supports_image_input"] is False - - # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. - assert by_id[-10_010]["supports_image_input"] is True - - # Unknown / unmapped model: default-allow rather than pre-judge. - assert by_id[-10_011]["supports_image_input"] is True - - for cfg in payload: - assert "supports_image_input" in cfg, ( - f"supports_image_input missing from {cfg.get('id')}" - ) - assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py index 53c0f50a9..4dd918927 100644 --- a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -27,9 +27,18 @@ async def test_resolve_billing_for_auto_mode(monkeypatch): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(image_generation_config_id=None) + async def _no_auto_candidates(*_args, **_kwargs): + return [] + + monkeypatch.setattr( + image_generation_routes, + "auto_model_candidates", + _no_auto_candidates, + ) + + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( - session=None, # Not consumed on this code path. + session=None, config_id=0, # IMAGE_GEN_AUTO_MODE_ID search_space=search_space, ) @@ -45,26 +54,42 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_IMAGE_GEN_CONFIGS", + "GLOBAL_MODELS", [ { "id": -1, - "litellm_provider": "openai", - "model_name": "gpt-image-1", + "connection_id": -101, + "model_id": "gpt-image-1", "billing_tier": "premium", - "quota_reserve_micros": 75_000, + "catalog": {"quota_reserve_micros": 75_000}, }, { "id": -2, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash-image", + "connection_id": -102, + "model_id": "google/gemini-2.5-flash-image", "billing_tier": "free", + "catalog": {}, + }, + ], + raising=False, + ) + monkeypatch.setattr( + config, + "GLOBAL_CONNECTIONS", + [ + {"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}, + { + "id": -102, + "provider": "openrouter", + "api_key": "sk-or-test", + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, }, ], raising=False, ) - search_space = SimpleNamespace(image_generation_config_id=None) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) # Premium with override. tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( @@ -94,7 +119,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(image_generation_config_id=None) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( session=None, config_id=42, search_space=search_space ) @@ -105,7 +130,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): @pytest.mark.asyncio async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): - """When the request omits ``image_generation_config_id``, the helper + """When the request omits ``image_gen_model_id``, the helper must consult the search space's default — so a search space pinned to a premium global config still gates new requests by quota. """ @@ -114,19 +139,26 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_IMAGE_GEN_CONFIGS", + "GLOBAL_MODELS", [ { "id": -7, - "litellm_provider": "openai", - "model_name": "gpt-image-1", + "connection_id": -101, + "model_id": "gpt-image-1", "billing_tier": "premium", + "catalog": {}, } ], raising=False, ) + monkeypatch.setattr( + config, + "GLOBAL_CONNECTIONS", + [{"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}], + raising=False, + ) - search_space = SimpleNamespace(image_generation_config_id=-7) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7) ( tier, model, diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py index fa8819b39..b43540ba7 100644 --- a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -1,27 +1,4 @@ -"""Unit tests for ``_resolve_agent_billing_for_search_space``. - -Validates the resolver used by Celery podcast/video tasks to compute -``(owner_user_id, billing_tier, base_model)`` from a search space and its -agent LLM config. The resolver mirrors chat's billing-resolution pattern at -``stream_new_chat.py:2294-2351`` and is the single integration point that -prevents Auto-mode podcast/video from leaking premium credit. - -Coverage: - -* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium - global → returns ``("premium", )``. -* Auto mode + ``thread_id`` set, pin resolves to a negative-id free - global → returns ``("free", )``. -* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config - → always ``"free"``. -* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without - hitting the pin service. -* Negative id (no Auto) → uses ``get_global_llm_config``'s - ``billing_tier``. -* Positive id (user BYOK) → always ``"free"``. -* Search space not found → raises ``ValueError``. -* ``agent_llm_id`` is None → raises ``ValueError``. -""" +"""Unit tests for ``_resolve_agent_billing_for_search_space``.""" from __future__ import annotations @@ -34,11 +11,6 @@ import pytest pytestmark = pytest.mark.unit -# --------------------------------------------------------------------------- -# Fakes -# --------------------------------------------------------------------------- - - class _FakeExecResult: def __init__(self, obj): self._obj = obj @@ -51,14 +23,6 @@ class _FakeExecResult: class _FakeSession: - """Tiny AsyncSession stub. - - ``responses`` is a list of objects to return from successive - ``execute()`` calls (in order). The resolver makes at most two - ``execute()`` calls (search-space lookup, then optionally NewLLMConfig - lookup), so two queued responses cover the matrix. - """ - def __init__(self, responses: list): self._responses = list(responses) @@ -67,9 +31,6 @@ class _FakeSession: return _FakeExecResult(None) return _FakeExecResult(self._responses.pop(0)) - async def commit(self) -> None: - pass - @dataclass class _FakePinResolution: @@ -78,53 +39,33 @@ class _FakePinResolution: from_existing_pin: bool = False -def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: - return SimpleNamespace( - id=42, - agent_llm_id=agent_llm_id, - user_id=user_id, - ) +def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id) -def _make_byok_config( - *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" +def _make_byok_model( + *, id_: int, base_model: str | None = None, model_id: str = "gpt-byok" ) -> SimpleNamespace: return SimpleNamespace( id=id_, - model_name=model_name, - litellm_params={"base_model": base_model} if base_model else {}, + model_id=model_id, + catalog={"base_model": base_model} if base_model else {}, + connection=SimpleNamespace(enabled=True, search_space_id=42, user_id=None), ) -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): - """Auto + thread → pin service resolves to negative-id premium config → - resolver returns ``("premium", )``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) - # Mock the pin service to return a concrete premium config id. - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): - assert selected_llm_config_id == 0 - assert thread_id == 99 + async def _fake_resolve_pin(*_args, **kwargs): + assert kwargs["selected_llm_config_id"] == 0 + assert kwargs["thread_id"] == 99 return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") - # Mock global config lookup to return a premium entry. def _fake_get_global(cfg_id): if cfg_id == -1: return { @@ -135,8 +76,6 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): } return None - # Lazy imports inside the resolver — patch the *target* modules so the - # imported names resolve to our fakes. import app.services.auto_model_pin_service as pin_module import app.services.llm_service as llm_module @@ -154,77 +93,18 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): assert base_model == "gpt-5.4" -@pytest.mark.asyncio -async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): - """Auto + thread → pin returns negative-id free config → resolver - returns ``("free", )``. Same path the pin service takes for - out-of-credit users (graceful degradation).""" - from app.services.billable_calls import _resolve_agent_billing_for_search_space - - user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) - - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): - return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") - - def _fake_get_global(cfg_id): - if cfg_id == -3: - return { - "id": -3, - "model_name": "openrouter/free-model", - "billing_tier": "free", - "litellm_params": {"base_model": "openrouter/free-model"}, - } - return None - - import app.services.auto_model_pin_service as pin_module - import app.services.llm_service as llm_module - - monkeypatch.setattr( - pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin - ) - monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) - - owner, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id=42, thread_id=99 - ) - - assert owner == user_id - assert tier == "free" - assert base_model == "openrouter/free-model" - - @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): - """Auto + thread → pin returns positive-id BYOK config → resolver - returns ``("free", ...)`` (BYOK is always free per - ``AgentConfig.from_new_llm_config``).""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(agent_llm_id=0, user_id=user_id) - byok_cfg = _make_byok_config( - id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + search_space = _make_search_space(chat_model_id=0, user_id=user_id) + byok_model = _make_byok_model( + id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude" ) - session = _FakeSession([search_space, byok_cfg]) + session = _FakeSession([search_space, byok_model]) - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): + async def _fake_resolve_pin(*_args, **_kwargs): return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") import app.services.auto_model_pin_service as pin_module @@ -244,13 +124,10 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): @pytest.mark.asyncio async def test_auto_mode_without_thread_id_falls_back_to_free(): - """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking - the pin service. Forward-compat fallback for any future direct-API - entrypoint that doesn't have a chat thread.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42, thread_id=None @@ -263,13 +140,10 @@ async def test_auto_mode_without_thread_id_falls_back_to_free(): @pytest.mark.asyncio async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): - """If the pin service raises ``ValueError`` (thread missing / - mismatched search space), the resolver should log and return free - rather than killing the whole task.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) async def _fake_resolve_pin(*args, **kwargs): raise ValueError("thread missing") @@ -291,12 +165,10 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): @pytest.mark.asyncio async def test_negative_id_premium_global_returns_premium(monkeypatch): - """Explicit negative agent_llm_id → ``get_global_llm_config`` → - return its ``billing_tier``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)]) def _fake_get_global(cfg_id): return { @@ -319,50 +191,15 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch): assert base_model == "gpt-5.4" -@pytest.mark.asyncio -async def test_negative_id_free_global_returns_free(monkeypatch): - from app.services.billable_calls import _resolve_agent_billing_for_search_space - - user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) - - def _fake_get_global(cfg_id): - return { - "id": cfg_id, - "model_name": "openrouter/some-free", - "billing_tier": "free", - "litellm_params": {"base_model": "openrouter/some-free"}, - } - - import app.services.llm_service as llm_module - - monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) - - owner, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id=42, thread_id=None - ) - - assert owner == user_id - assert tier == "free" - assert base_model == "openrouter/some-free" - - @pytest.mark.asyncio async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): - """When the global config has no ``litellm_params.base_model``, the - resolver falls back to ``model_name`` — matching chat's behavior.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)]) def _fake_get_global(cfg_id): - return { - "id": cfg_id, - "model_name": "fallback-model", - "billing_tier": "premium", - # No litellm_params. - } + return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"} import app.services.llm_service as llm_module @@ -378,14 +215,12 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat @pytest.mark.asyncio async def test_positive_id_byok_is_always_free(): - """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, - regardless of underlying provider tier.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(agent_llm_id=23, user_id=user_id) - byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") - session = _FakeSession([search_space, byok_cfg]) + search_space = _make_search_space(chat_model_id=23, user_id=user_id) + byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_model]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -398,13 +233,10 @@ async def test_positive_id_byok_is_always_free(): @pytest.mark.asyncio async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): - """If the BYOK config row is missing/deleted but the search space still - points at it, the resolver still returns free (no debit) with an empty - base_model — billable_call's premium path is skipped, no harm done.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -419,18 +251,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): async def test_search_space_not_found_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space - session = _FakeSession([None]) - with pytest.raises(ValueError, match="Search space"): - await _resolve_agent_billing_for_search_space(session, search_space_id=999) + await _resolve_agent_billing_for_search_space( + _FakeSession([None]), search_space_id=999 + ) @pytest.mark.asyncio -async def test_agent_llm_id_none_raises_value_error(): +async def test_chat_model_id_none_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)]) - with pytest.raises(ValueError, match="agent_llm_id"): + with pytest.raises(ValueError, match="chat_model_id"): await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d7eb32732..d7c12a6e0 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -32,8 +32,9 @@ class _FakeQuotaResult: class _FakeExecResult: - def __init__(self, thread): + def __init__(self, *, thread=None, scalars=None): self._thread = thread + self._scalars = scalars or [] def unique(self): return self @@ -41,19 +42,69 @@ class _FakeExecResult: def scalar_one_or_none(self): return self._thread + def scalars(self): + return SimpleNamespace(all=lambda: self._scalars) + class _FakeSession: - def __init__(self, thread): + def __init__(self, thread, *, models=None): self.thread = thread + self.models = models or [] self.commit_count = 0 + self.execute_count = 0 async def execute(self, _stmt): - return _FakeExecResult(self.thread) + self.execute_count += 1 + if self.execute_count == 1: + return _FakeExecResult(thread=self.thread) + return _FakeExecResult(scalars=self.models) async def commit(self): self.commit_count += 1 +def _set_global_llm_configs(monkeypatch, config, configs: list[dict]): + """Patch the new global model catalog shape from compact legacy cfg fixtures.""" + connections = [] + models = [] + for cfg in configs: + config_id = int(cfg["id"]) + connection_id = config_id - 100_000 + provider = cfg.get("provider") or cfg.get("litellm_provider") + model_name = cfg["model_name"] + connections.append( + { + "id": connection_id, + "provider": provider, + "scope": "GLOBAL", + "enabled": True, + } + ) + models.append( + { + "id": config_id, + "connection_id": connection_id, + "model_id": model_name, + "display_name": cfg.get("name") or model_name, + "supports_chat": cfg.get("supports_chat", True), + "supports_image_input": cfg.get("supports_image_input", True), + "supports_tools": cfg.get("supports_tools", True), + "supports_image_generation": cfg.get("supports_image_generation", False), + "capabilities_override": cfg.get("capabilities_override") or {}, + "billing_tier": cfg.get("billing_tier", "free"), + "catalog": { + "auto_pin_tier": cfg.get("auto_pin_tier"), + "quality_score": cfg.get("quality_score") + or cfg.get("quality_score_static"), + }, + } + ) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs) + monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections) + monkeypatch.setattr(config, "GLOBAL_MODELS", models) + + def _thread( *, search_space_id: int = 10, @@ -71,9 +122,9 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, { @@ -111,9 +162,9 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -158,9 +209,9 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -216,9 +267,9 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -257,9 +308,9 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -295,9 +346,9 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -340,9 +391,9 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -385,9 +436,9 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -433,9 +484,9 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-2)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], @@ -458,9 +509,9 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-999)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], @@ -487,7 +538,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): # --------------------------------------------------------------------------- -# Quality-aware pin selection (Auto Fastest upgrade) +# Quality-aware pin selection (Auto upgrade) # --------------------------------------------------------------------------- @@ -498,9 +549,9 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -550,9 +601,9 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -602,9 +653,9 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -676,9 +727,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): "quality_score": 10, "health_gated": False, } - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [*high_score_cfgs, low_score_trap], ) @@ -723,9 +774,9 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -775,9 +826,9 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -833,9 +884,9 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -886,9 +937,9 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -931,9 +982,9 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 63aa934a3..5850dfe23 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -15,15 +15,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): """The global-config branch forwards the explicit OpenRouter base.""" from app.routes import image_generation_routes - cfg = { + global_model = { "id": -20_001, - "name": "GPT Image 1 (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-image-1", + "connection_id": -101, + "model_id": "openai/gpt-image-1", + "supports_image_generation": True, + "capabilities_override": {}, + } + global_connection = { + "id": -101, + "provider": "openrouter", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, } captured: dict = {} @@ -33,7 +37,7 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) image_gen = MagicMock() - image_gen.image_generation_config_id = cfg["id"] + image_gen.image_gen_model_id = global_model["id"] image_gen.prompt = "test" image_gen.n = 1 image_gen.quality = None @@ -43,14 +47,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): image_gen.model = None search_space = MagicMock() - search_space.image_generation_config_id = cfg["id"] + search_space.image_gen_model_id = global_model["id"] session = MagicMock() with ( patch.object( image_generation_routes, - "_get_global_image_gen_config", - return_value=cfg, + "_get_global_model", + return_value=global_model, + ), + patch.object( + image_generation_routes, + "_get_global_connection", + return_value=global_connection, ), patch.object( image_generation_routes, @@ -74,15 +83,19 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): generate_image as gi_module, ) - cfg = { + global_model = { "id": -20_001, - "name": "GPT Image 1 (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-image-1", + "connection_id": -101, + "model_id": "openai/gpt-image-1", + "supports_image_generation": True, + "capabilities_override": {}, + } + global_connection = { + "id": -101, + "provider": "openrouter", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, } captured: dict = {} @@ -98,7 +111,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): search_space = MagicMock() search_space.id = 1 - search_space.image_generation_config_id = cfg["id"] + search_space.image_gen_model_id = global_model["id"] session_cm = AsyncMock() session = AsyncMock() @@ -121,7 +134,8 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): with ( patch.object(gi_module, "shielded_async_session", return_value=session_cm), - patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), + patch.object(gi_module, "_get_global_model", return_value=global_model), + patch.object(gi_module, "_get_global_connection", return_value=global_connection), patch.object( gi_module, "aimage_generation", side_effect=fake_aimage_generation ), diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 9d4c1a04b..ee97aac4d 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): # --------------------------------------------------------------------------- -# _generate_image_gen_configs / _generate_vision_llm_configs +# _generate_image_gen_configs # --------------------------------------------------------------------------- @@ -263,7 +263,7 @@ def test_generate_image_gen_configs_filters_by_image_output(): # Each config must carry ``billing_tier`` for routing in image_generation_routes. for c in cfgs: assert c["billing_tier"] in {"free", "premium"} - assert c["litellm_provider"] == "openrouter" + assert c["provider"] == "openrouter" assert c[_OPENROUTER_DYNAMIC_MARKER] is True # Emit the OpenRouter base URL at source so every call path passes an # explicit api_base and cannot inherit a process-global endpoint. @@ -271,9 +271,7 @@ def test_generate_image_gen_configs_filters_by_image_output(): def test_generate_image_gen_configs_assigns_image_id_offset(): - """Image configs use a different id_offset (-20000) so their negative - IDs don't collide with chat configs (-10000) or vision configs (-30000). - """ + """Image configs use their own id_offset (-20000).""" from app.services.openrouter_integration_service import ( _generate_image_gen_configs, ) @@ -291,88 +289,3 @@ def test_generate_image_gen_configs_assigns_image_id_offset(): assert all(c["id"] < -20_000 + 1 for c in cfgs) assert all(c["id"] > -29_000_000 for c in cfgs) - -def test_generate_vision_llm_configs_filters_by_image_input_text_output(): - """Vision LLMs must accept image input AND emit text — pure image-gen - (no text out) and text-only (no image in) models are excluded. - """ - from app.services.openrouter_integration_service import ( - _generate_vision_llm_configs, - ) - - raw = [ - # GPT-4o: vision LLM (image in, text out) — must emit. - { - "id": "openai/gpt-4o", - "architecture": { - "input_modalities": ["text", "image"], - "output_modalities": ["text"], - }, - "context_length": 128_000, - "pricing": {"prompt": "0.000005", "completion": "0.000015"}, - }, - # Pure image generator — image *output*, no text out. Must NOT emit. - { - "id": "openai/gpt-image-1", - "architecture": { - "input_modalities": ["text"], - "output_modalities": ["image"], - }, - "context_length": 4_000, - "pricing": {"prompt": "0", "completion": "0"}, - }, - # Pure text model (no image in). Must NOT emit. - { - "id": "anthropic/claude-3-haiku", - "architecture": { - "input_modalities": ["text"], - "output_modalities": ["text"], - }, - "context_length": 200_000, - "pricing": {"prompt": "0.000001", "completion": "0.000005"}, - }, - ] - - cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) - names = {c["model_name"] for c in cfgs} - assert names == {"openai/gpt-4o"} - - cfg = cfgs[0] - assert cfg["billing_tier"] == "premium" - # Pricing carried inline so pricing_registration can register vision - # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache - # is cleared. - assert cfg["input_cost_per_token"] == pytest.approx(5e-6) - assert cfg["output_cost_per_token"] == pytest.approx(15e-6) - assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True - # Emit the OpenRouter base URL at source so every call path passes an - # explicit api_base and cannot inherit a process-global endpoint. - assert cfg["api_base"] == "https://openrouter.ai/api/v1" - - -def test_generate_vision_llm_configs_drops_chat_only_filters(): - """A small-context vision model that doesn't advertise tool calling is - still a valid vision LLM for "describe this image" prompts. The chat - filters (``supports_tool_calling``, ``has_sufficient_context``) must - NOT be applied to vision emission. - """ - from app.services.openrouter_integration_service import ( - _generate_vision_llm_configs, - ) - - raw = [ - { - "id": "tiny/vision-mini", - "architecture": { - "input_modalities": ["text", "image"], - "output_modalities": ["text"], - }, - "supported_parameters": [], # no tools - "context_length": 4_000, # well below MIN_CONTEXT_LENGTH - "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, - } - ] - - cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) - assert len(cfgs) == 1 - assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py index c9adc6aac..ee2faf674 100644 --- a/surfsense_backend/tests/unit/services/test_pricing_registration.py +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -370,77 +370,3 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): assert any("custom-deployment" in payload for payload in successful_calls) -def test_vision_configs_registered_with_chat_shape(monkeypatch): - """``register_pricing_from_global_configs`` walks - ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision - calls (during indexing) bill correctly. Vision configs use the same - chat-shape token prices, but image-gen pricing is intentionally NOT - registered here (handled via ``response_cost`` in LiteLLM). - """ - from app.config import config - from app.services.pricing_registration import register_pricing_from_global_configs - - spy = _patch_register(monkeypatch) - _patch_openrouter_pricing( - monkeypatch, - {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, - ) - - # No chat configs — only vision. Proves the vision walk is a separate - # iteration, not piggy-backed on the chat list. - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) - monkeypatch.setattr( - config, - "GLOBAL_VISION_LLM_CONFIGS", - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "billing_tier": "premium", - "input_cost_per_token": 5e-6, - "output_cost_per_token": 15e-6, - } - ], - ) - - register_pricing_from_global_configs() - - assert "openrouter/openai/gpt-4o" in spy.all_keys - payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] - assert payload_value["mode"] == "chat" - assert payload_value["litellm_provider"] == "openrouter" - assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) - assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) - - -def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): - """If the OpenRouter pricing cache misses a vision model (different - catalogue surface), the vision walk falls back to inline - ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. - """ - from app.config import config - from app.services.pricing_registration import register_pricing_from_global_configs - - spy = _patch_register(monkeypatch) - _patch_openrouter_pricing(monkeypatch, {}) - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) - monkeypatch.setattr( - config, - "GLOBAL_VISION_LLM_CONFIGS", - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash", - "billing_tier": "premium", - "input_cost_per_token": 1e-6, - "output_cost_per_token": 4e-6, - } - ], - ) - - register_pricing_from_global_configs() - - assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index 369c8b8f3..cb3f7523a 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -1,4 +1,4 @@ -"""Unit tests for the Auto (Fastest) quality scoring module.""" +"""Unit tests for the Auto quality scoring module.""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py deleted file mode 100644 index 48dfc8e0b..000000000 --- a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Vision LLM resolution must pass explicit per-config ``api_base``.""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -pytestmark = pytest.mark.unit - - -@pytest.mark.asyncio -async def test_get_vision_llm_global_openrouter_sets_api_base(): - """Global negative-ID branch forwards the explicit OpenRouter base.""" - from app.services import llm_service - - cfg = { - "id": -30_001, - "name": "GPT-4o Vision (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, - "billing_tier": "free", - } - - search_space = MagicMock() - search_space.id = 1 - search_space.user_id = "user-x" - search_space.vision_llm_config_id = cfg["id"] - - session = AsyncMock() - scalars = MagicMock() - scalars.first.return_value = search_space - result = MagicMock() - result.scalars.return_value = scalars - session.execute.return_value = result - - captured: dict = {} - - class FakeSanitized: - def __init__(self, **kwargs): - captured.update(kwargs) - - with ( - patch( - "app.services.vision_llm_router_service.get_global_vision_llm_config", - return_value=cfg, - ), - patch( - "app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM", - new=FakeSanitized, - ), - ): - await llm_service.get_vision_llm(session=session, search_space_id=1) - - assert captured.get("api_base") == "https://openrouter.ai/api/v1" - assert captured["model"] == "openrouter/openai/gpt-4o" - - -def test_vision_router_deployment_sets_api_base_when_config_empty(): - """Auto-mode vision router carries explicit api_base into deployments.""" - from app.services.vision_llm_router_service import VisionLLMRouterService - - deployment = VisionLLMRouterService._config_to_deployment( - { - "model_name": "openai/gpt-4o", - "litellm_provider": "openrouter", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - } - ) - assert deployment is not None - assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" - assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_evals/README.md b/surfsense_evals/README.md index c755c4de6..e6fc52ca1 100644 --- a/surfsense_evals/README.md +++ b/surfsense_evals/README.md @@ -77,7 +77,7 @@ The walkthrough above is `--scenario head-to-head` (default): both arms answer w | `symmetric-cheap` | `--provider-model` (cheap, text-only) | `--provider-model` (same) | Does pre-extracted image context let a non-vision LLM reason over image-heavy docs? | | `cost-arbitrage` | `--native-arm-model` (vision) | `--provider-model` (cheap) | How close does SurfSense get to a vision-native baseline at a fraction of per-query cost?| -In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_llm_config_id` (auto-picked from the strongest registered global OpenRouter vision config — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. +In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_model_id` (auto-picked from the strongest registered global OpenRouter vision-capable model — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. ### Ingest with vision, evaluate with a non-vision LLM (`symmetric-cheap`) @@ -118,7 +118,7 @@ python -m surfsense_evals report --suite medical Notes: - `cost-arbitrage` requires both `--provider-model` (the cheap SurfSense slug) AND `--native-arm-model `. -- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/global-vision-llm-configs` and auto-picks the strongest registered one. Pass `--no-vision-llm-setup` if you want to keep whatever vision config is already attached to the SearchSpace. +- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/model-connections/global` and auto-picks the strongest registered vision-capable model. Pass `--no-vision-llm-setup` if you want to keep whatever vision model is already attached to the SearchSpace. - The runner's "looks text-only" warning is suppressed (or relabelled as informational) for `symmetric-cheap` so intentional asymmetry doesn't read as a misconfiguration. - All three scenario fields (`scenario`, `provider_model`, `native_arm_model`, `vision_provider_model`) are persisted to `state.json` and recorded in `run_artifact.extra` + the report header — no need to retrace what was set. diff --git a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json index a4687f64a..b6c59e2bc 100644 --- a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json +++ b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json @@ -9,7 +9,7 @@ "llamacloud_premium_lc", "surfsense_agentic" ], - "agent_llm_id": -5138454, + "chat_model_id": -5138454, "concurrency": 2, "llm_model": "anthropic/claude-sonnet-4.5", "n_pdfs": 30, diff --git a/surfsense_evals/src/surfsense_evals/core/cli.py b/surfsense_evals/src/surfsense_evals/core/cli.py index 3d4d0fd24..17979fba0 100644 --- a/surfsense_evals/src/surfsense_evals/core/cli.py +++ b/surfsense_evals/src/surfsense_evals/core/cli.py @@ -2,7 +2,7 @@ Subcommands: -* ``setup --suite --provider-model [--agent-llm-id ]`` +* ``setup --suite --provider-model [--chat-model-id ]`` * ``teardown --suite `` * ``models list [--provider openrouter] [--grep ]`` * ``suites list`` @@ -18,7 +18,7 @@ publish its own flags. Design choices worth flagging: -* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so +* ``setup`` rejects ``chat_model_id == 0`` (Auto / LiteLLM router) so per-question accuracy is reproducible. * ``setup`` validates that the picked LLM config has ``provider == "OPENROUTER"`` and ``model_name == --provider-model`` @@ -59,7 +59,6 @@ if sys.platform == "win32": from . import registry from .auth import CredentialError, acquire_token, client_with_auth from .clients import SearchSpaceClient -from .clients.search_space import LlmPreferences from .config import ( DEFAULT_SCENARIO, SCENARIOS, @@ -111,23 +110,30 @@ class LlmConfigEntry: def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry: return cls( id=int(payload["id"]), - name=str(payload.get("name", "")), + name=str(payload.get("display_name") or payload.get("name") or ""), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_name", "")), + model_name=str(payload.get("model_id") or payload.get("model_name") or ""), raw=payload, ) async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]: response = await http.get( - f"{base}/api/v1/global-new-llm-configs", + f"{base}/api/v1/model-connections/global", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): - raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {payload!r}") - return [LlmConfigEntry.from_payload(item) for item in payload] + raise RuntimeError(f"Unexpected /model-connections/global payload: {payload!r}") + entries: list[LlmConfigEntry] = [] + for connection in payload: + provider = connection.get("provider", "") + for model in connection.get("models") or []: + if not model.get("enabled", True) or not model.get("supports_chat"): + continue + entries.append(LlmConfigEntry.from_payload({**model, "provider": provider})) + return entries def _resolve_openrouter_id( @@ -143,8 +149,8 @@ def _resolve_openrouter_id( * If ``explicit_id`` is given: return it directly. The caller is then expected to GET-validate that the row's ``provider == "OPENROUTER"`` and ``model_name`` matches the slug. - That branch supports positive BYOK ``NewLLMConfig`` rows whose - slugs may overlap with global OpenRouter virtuals. + That branch supports positive BYOK model rows whose slugs may overlap + with global OpenRouter virtuals. * Otherwise: filter to ``provider == "OPENROUTER"`` and ``model_name == provider_model``. Expect exactly one match — raise with a friendly message otherwise. @@ -173,7 +179,7 @@ def _resolve_openrouter_id( listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches) raise RuntimeError( f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n" - "Pass --agent-llm-id to disambiguate." + "Pass --chat-model-id to disambiguate." ) return matches[0].id @@ -186,7 +192,7 @@ def _resolve_openrouter_id( async def _cmd_setup(args: argparse.Namespace) -> int: suite = args.suite provider_model: str = args.provider_model - explicit_id: int | None = args.agent_llm_id + explicit_id: int | None = args.chat_model_id scenario: str = args.scenario vision_llm_slug: str | None = args.vision_llm native_arm_model: str | None = args.native_arm_model @@ -194,7 +200,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: if explicit_id == 0: console.print( - "[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — " + "[red]chat_model_id == 0 (Auto / LiteLLM router) is not allowed — " "results would not be reproducible.[/red]" ) return 2 @@ -242,7 +248,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: candidates = await _list_global_llm_configs(http, config.surfsense_api_base) try: - agent_llm_id = _resolve_openrouter_id( + chat_model_id = _resolve_openrouter_id( candidates, provider_model, explicit_id=explicit_id ) except RuntimeError as exc: @@ -288,7 +294,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: vision_provider_model: str | None = None if not skip_vision_setup and (vision_required or vision_llm_slug is not None): try: - vision_candidates = await ss_client.list_global_vision_llm_configs() + vision_candidates = await ss_client.list_global_vision_models() resolved = resolve_vision_llm( vision_candidates, explicit_slug=vision_llm_slug ) @@ -302,37 +308,34 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"(id={vision_config_id}, selected_via={resolved.selected_via})." ) - pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id} + role_kwargs: dict[str, Any] = {"chat_model_id": chat_model_id} if vision_config_id is not None: - pref_kwargs["vision_llm_config_id"] = vision_config_id + role_kwargs["vision_model_id"] = vision_config_id - await ss_client.set_llm_preferences(search_space_id, **pref_kwargs) - prefs = await ss_client.get_llm_preferences(search_space_id) - if not _validate_pin(prefs, provider_model): - agent = prefs.agent_llm or {} + await ss_client.set_model_roles(search_space_id, **role_kwargs) + roles = await ss_client.get_model_roles(search_space_id) + if roles.chat_model_id != chat_model_id: console.print( f"[red]LLM pin validation FAILED.[/red] After PUT, " - f"agent_llm.provider={agent.get('provider')!r}, " - f"model_name={agent.get('model_name')!r}; expected " - f"provider=OPENROUTER, model_name={provider_model!r}." + f"chat_model_id={roles.chat_model_id!r}; expected {chat_model_id!r}." ) return 2 - if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id: + if vision_config_id is not None and roles.vision_model_id != vision_config_id: console.print( f"[red]Vision LLM pin validation FAILED.[/red] After PUT, " - f"vision_llm_config_id={prefs.vision_llm_config_id!r}; " + f"vision_model_id={roles.vision_model_id!r}; " f"expected {vision_config_id!r}." ) return 2 suite_state = SuiteState( search_space_id=search_space_id, - agent_llm_id=agent_llm_id, + chat_model_id=chat_model_id, provider_model=provider_model, created_at=utc_iso_timestamp(), ingestion_maps=existing.ingestion_maps if existing else {}, scenario=scenario, - vision_llm_config_id=vision_config_id, + vision_model_id=vision_config_id, vision_provider_model=vision_provider_model, native_arm_model=native_arm_model, ) @@ -342,7 +345,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"suite={suite!r}", f"scenario={scenario!r}", f"search_space_id={suite_state.search_space_id}", - f"agent_llm_id={suite_state.agent_llm_id}", + f"chat_model_id={suite_state.chat_model_id}", f"provider_model={suite_state.provider_model!r}", ] if suite_state.vision_provider_model: @@ -353,14 +356,6 @@ async def _cmd_setup(args: argparse.Namespace) -> int: return 0 -def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool: - agent = prefs.agent_llm or {} - return ( - str(agent.get("provider", "")).upper() == "OPENROUTER" - and str(agent.get("model_name", "")) == provider_model - ) - - async def _cmd_teardown(args: argparse.Namespace) -> int: suite = args.suite config = load_config() @@ -654,10 +649,10 @@ def _build_parser() -> argparse.ArgumentParser: ), ) p_setup.add_argument( - "--agent-llm-id", + "--chat-model-id", type=int, default=None, - help="Optional override for BYOK NewLLMConfig rows.", + help="Optional explicit model id override.", ) p_setup.add_argument( "--scenario", diff --git a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py index e2d37694d..efd4a571d 100644 --- a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py +++ b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py @@ -1,17 +1,16 @@ -"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``. +"""Client for ``/api/v1/searchspaces`` and model-role endpoints. Verified against: * ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create) * ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id) * ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete) -* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences) +* ``surfsense_backend/app/routes/model_connections_routes.py`` (GET/PUT model roles) * ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body) -* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs) Note the inconsistent pluralisation in the backend: ``/searchspaces`` -(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the -``llm-preferences`` sub-resource. Both are mirrored verbatim here. +(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for model-role +sub-resources. Both are mirrored verbatim here. """ from __future__ import annotations @@ -46,13 +45,8 @@ class SearchSpaceRow: @dataclass -class VisionLlmConfigEntry: - """Subset of one ``GET /global-vision-llm-configs`` row. - - The backend returns negative ids for global / OpenRouter-derived - vision configs and positive ids for per-user BYOK rows. Either is - accepted by ``set_llm_preferences(vision_llm_config_id=...)``. - """ +class VisionModelEntry: + """Subset of one GLOBAL model-connection model with image input support.""" id: int name: str @@ -62,45 +56,38 @@ class VisionLlmConfigEntry: raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry: + def from_payload(cls, payload: dict[str, Any]) -> VisionModelEntry: return cls( id=int(payload.get("id", 0)), - name=str(payload.get("name", "")), + name=str(payload.get("display_name") or payload.get("model_id") or ""), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_name", "")), - is_auto_mode=bool(payload.get("is_auto_mode", False)), + model_name=str(payload.get("model_id", "")), + is_auto_mode=False, raw=payload, ) @dataclass -class LlmPreferences: - """Resolved LLM preferences with the embedded full config row. +class ModelRoles: + """Model role ids for a search space.""" - Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle - command can introspect ``provider`` / ``model_name`` to validate the - OpenRouter pin. - """ - - agent_llm_id: int | None - image_generation_config_id: int | None - vision_llm_config_id: int | None - agent_llm: dict[str, Any] | None + chat_model_id: int | None + image_gen_model_id: int | None + vision_model_id: int | None raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences: + def from_payload(cls, payload: dict[str, Any]) -> ModelRoles: return cls( - agent_llm_id=payload.get("agent_llm_id"), - image_generation_config_id=payload.get("image_generation_config_id"), - vision_llm_config_id=payload.get("vision_llm_config_id"), - agent_llm=payload.get("agent_llm"), + chat_model_id=payload.get("chat_model_id"), + image_gen_model_id=payload.get("image_gen_model_id"), + vision_model_id=payload.get("vision_model_id"), raw=payload, ) class SearchSpaceClient: - """Thin wrapper around the SearchSpace + LLM preferences endpoints.""" + """Thin wrapper around the SearchSpace + model role endpoints.""" def __init__(self, http: httpx.AsyncClient, base_url: str) -> None: self._http = http @@ -139,64 +126,67 @@ class SearchSpaceClient: return response.raise_for_status() - async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences: + async def get_model_roles(self, search_space_id: int) -> ModelRoles: response = await self._http.get( - f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", + f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", headers={"Accept": "application/json"}, ) response.raise_for_status() - return LlmPreferences.from_payload(response.json()) + return ModelRoles.from_payload(response.json()) - async def set_llm_preferences( + async def set_model_roles( self, search_space_id: int, *, - agent_llm_id: int | None = None, - image_generation_config_id: int | None = None, - vision_llm_config_id: int | None = None, - ) -> LlmPreferences: - """PUT a partial update to ``/search-spaces/{id}/llm-preferences``. + chat_model_id: int | None = None, + image_gen_model_id: int | None = None, + vision_model_id: int | None = None, + ) -> ModelRoles: + """PUT a partial update to ``/search-spaces/{id}/model-roles``. Backend uses ``model_dump(exclude_unset=True)`` so omitted fields are left unchanged. """ body: dict[str, Any] = {} - if agent_llm_id is not None: - body["agent_llm_id"] = agent_llm_id - if image_generation_config_id is not None: - body["image_generation_config_id"] = image_generation_config_id - if vision_llm_config_id is not None: - body["vision_llm_config_id"] = vision_llm_config_id + if chat_model_id is not None: + body["chat_model_id"] = chat_model_id + if image_gen_model_id is not None: + body["image_gen_model_id"] = image_gen_model_id + if vision_model_id is not None: + body["vision_model_id"] = vision_model_id response = await self._http.put( - f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", + f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", json=body, headers={"Accept": "application/json"}, ) response.raise_for_status() - return LlmPreferences.from_payload(response.json()) + return ModelRoles.from_payload(response.json()) - async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]: - """List the registered global vision LLM configs. + async def list_global_vision_models(self) -> list[VisionModelEntry]: + """List registered GLOBAL models that can accept image input. - Used by ``setup`` to (a) resolve an explicit ``--vision-llm `` - to a config id and (b) auto-pick the strongest registered vision - config when the operator doesn't pass one. The ``Auto (Fastest)`` - entry (``id=0``) is filtered out — accuracy must be reproducible. + Used by ``setup`` to resolve ``--vision-llm `` or auto-pick a + reproducible ingest-time vision model. """ response = await self._http.get( - f"{self._base}/api/v1/global-vision-llm-configs", + f"{self._base}/api/v1/model-connections/global", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): raise RuntimeError( - f"Unexpected /global-vision-llm-configs payload: {payload!r}" + f"Unexpected /model-connections/global payload: {payload!r}" ) - return [ - VisionLlmConfigEntry.from_payload(item) - for item in payload - if not bool(item.get("is_auto_mode", False)) - ] + entries: list[VisionModelEntry] = [] + for connection in payload: + provider = str(connection.get("provider", "")) + for model in connection.get("models") or []: + if not model.get("enabled", True) or not model.get("supports_image_input"): + continue + entries.append( + VisionModelEntry.from_payload({**model, "provider": provider}) + ) + return entries diff --git a/surfsense_evals/src/surfsense_evals/core/config.py b/surfsense_evals/src/surfsense_evals/core/config.py index 164955914..9a5a71e89 100644 --- a/surfsense_evals/src/surfsense_evals/core/config.py +++ b/surfsense_evals/src/surfsense_evals/core/config.py @@ -147,35 +147,35 @@ class SuiteState: """Per-suite persisted state. ``provider_model`` is the slug pinned to the SearchSpace's - ``agent_llm`` — what answers SurfSense queries (and what the native + ``chat_model_id`` — what answers SurfSense queries (and what the native arm uses too, unless ``native_arm_model`` is set for cost-arbitrage). - ``vision_provider_model`` is the slug of the OpenRouter vision LLM - config attached to the SearchSpace's ``vision_llm_config_id`` — what + ``vision_provider_model`` is the slug of the OpenRouter vision model + attached to the SearchSpace's ``vision_model_id`` — what SurfSense uses to extract image content at ingest time when ``use_vision_llm=True``. ``None`` means no vision config was attached at setup (legacy or text-only suite). """ search_space_id: int - agent_llm_id: int + chat_model_id: int provider_model: str created_at: str ingestion_maps: dict[str, str] = field(default_factory=dict) scenario: str = DEFAULT_SCENARIO - vision_llm_config_id: int | None = None + vision_model_id: int | None = None vision_provider_model: str | None = None native_arm_model: str | None = None def to_dict(self) -> dict[str, Any]: return { "search_space_id": self.search_space_id, - "agent_llm_id": self.agent_llm_id, + "chat_model_id": self.chat_model_id, "provider_model": self.provider_model, "created_at": self.created_at, "ingestion_maps": dict(self.ingestion_maps), "scenario": self.scenario, - "vision_llm_config_id": self.vision_llm_config_id, + "vision_model_id": self.vision_model_id, "vision_provider_model": self.vision_provider_model, "native_arm_model": self.native_arm_model, } @@ -187,15 +187,16 @@ class SuiteState: scenario = str(payload.get("scenario") or DEFAULT_SCENARIO) if scenario not in SCENARIOS: scenario = DEFAULT_SCENARIO - raw_vision_id = payload.get("vision_llm_config_id") + raw_chat_id = payload.get("chat_model_id") + raw_vision_id = payload.get("vision_model_id") return cls( search_space_id=int(payload["search_space_id"]), - agent_llm_id=int(payload["agent_llm_id"]), + chat_model_id=int(raw_chat_id), provider_model=str(payload["provider_model"]), created_at=str(payload.get("created_at") or ""), ingestion_maps=dict(payload.get("ingestion_maps") or {}), scenario=scenario, - vision_llm_config_id=int(raw_vision_id) if raw_vision_id is not None else None, + vision_model_id=int(raw_vision_id) if raw_vision_id is not None else None, vision_provider_model=( str(payload["vision_provider_model"]) if payload.get("vision_provider_model") diff --git a/surfsense_evals/src/surfsense_evals/core/registry.py b/surfsense_evals/src/surfsense_evals/core/registry.py index cc8b725e0..65f64c39a 100644 --- a/surfsense_evals/src/surfsense_evals/core/registry.py +++ b/surfsense_evals/src/surfsense_evals/core/registry.py @@ -53,8 +53,8 @@ class RunContext: return self.suite_state.search_space_id @property - def agent_llm_id(self) -> int: - return self.suite_state.agent_llm_id + def chat_model_id(self) -> int: + return self.suite_state.chat_model_id @property def provider_model(self) -> str: diff --git a/surfsense_evals/src/surfsense_evals/core/vision_llm.py b/surfsense_evals/src/surfsense_evals/core/vision_llm.py index ae96f1285..5d5e2c6d1 100644 --- a/surfsense_evals/src/surfsense_evals/core/vision_llm.py +++ b/surfsense_evals/src/surfsense_evals/core/vision_llm.py @@ -3,8 +3,8 @@ Two responsibilities: 1. Resolve an explicit ``--vision-llm `` to a global OpenRouter - vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)`` - can accept. + vision-capable model id that ``set_model_roles(vision_model_id=...)`` can + accept. 2. Auto-pick the strongest registered vision config when the operator doesn't pass ``--vision-llm`` but the scenario / benchmark needs one. diff --git a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py index e1a830138..ac0651996 100644 --- a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py @@ -371,7 +371,7 @@ class MedXpertQAMMBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py index 95a1e15eb..b7685766e 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py @@ -391,7 +391,7 @@ class MMLongBenchDocBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py index e71dffa65..2c4a0ffe4 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py @@ -554,7 +554,7 @@ class ParserCompareBenchmark: "scenario": ctx.scenario, "provider_model": ctx.provider_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "preprocess_tariff": { "basic_per_1k_pages": 1.0, "premium_per_1k_pages": 10.0, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py index 8b759e0d8..654c261a2 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py @@ -467,7 +467,7 @@ class CragBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, "per_page_char_cap": per_page_char_cap, "max_output_tokens": max_output_tokens, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py index 9c0e16b00..450c7ddd6 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py @@ -372,7 +372,7 @@ class FramesBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, "bare_arm_label": "bare_llm", }, diff --git a/surfsense_evals/tests/core/test_clients.py b/surfsense_evals/tests/core/test_clients.py index 611408703..aa98f0ad4 100644 --- a/surfsense_evals/tests/core/test_clients.py +++ b/surfsense_evals/tests/core/test_clients.py @@ -63,29 +63,22 @@ async def test_delete_search_space_idempotent_on_404(respx_mock, http): @pytest.mark.asyncio @respx.mock(base_url=_BASE) -async def test_set_llm_preferences_partial_update(respx_mock, http): - route = respx_mock.put("/api/v1/search-spaces/42/llm-preferences").mock( +async def test_set_model_roles_partial_update(respx_mock, http): + route = respx_mock.put("/api/v1/search-spaces/42/model-roles").mock( return_value=httpx.Response( 200, json={ - "agent_llm_id": -10042, - "agent_llm_id": None, - "image_generation_config_id": None, - "vision_llm_config_id": None, - "agent_llm": { - "id": -10042, - "provider": "OPENROUTER", - "model_name": "anthropic/claude-sonnet-4.5", - }, + "chat_model_id": -10042, + "image_gen_model_id": None, + "vision_model_id": None, }, ) ) client = SearchSpaceClient(http, _BASE) - prefs = await client.set_llm_preferences(42, agent_llm_id=-10042) - assert prefs.agent_llm_id == -10042 - assert prefs.agent_llm["provider"] == "OPENROUTER" + roles = await client.set_model_roles(42, chat_model_id=-10042) + assert roles.chat_model_id == -10042 sent_body = json.loads(route.calls[-1].request.content) - assert sent_body == {"agent_llm_id": -10042} + assert sent_body == {"chat_model_id": -10042} # --------------------------------------------------------------------------- diff --git a/surfsense_evals/tests/core/test_config.py b/surfsense_evals/tests/core/test_config.py index f7b8f7249..6f9671c86 100644 --- a/surfsense_evals/tests/core/test_config.py +++ b/surfsense_evals/tests/core/test_config.py @@ -41,14 +41,14 @@ def test_state_roundtrip_per_suite(tmp_env): # noqa: ARG001 assert get_suite_state(config, "medical") is None state = SuiteState( search_space_id=1, - agent_llm_id=-10042, + chat_model_id=-10042, provider_model="anthropic/claude-sonnet-4.5", created_at="2026-05-11T20-30-00Z", ) set_suite_state(config, "medical", state) legal = SuiteState( search_space_id=2, - agent_llm_id=-1, + chat_model_id=-1, provider_model="openai/gpt-5", created_at="2026-05-11T21-00-00Z", ) @@ -84,25 +84,19 @@ def test_paths_are_per_suite(tmp_env): # noqa: ARG001 # --------------------------------------------------------------------------- -def test_legacy_state_back_compat_defaults_to_head_to_head(): - """state.json files written before scenarios shipped must still load. +def test_minimal_state_defaults_to_head_to_head(): + """Missing scenario / vision / native fields default safely.""" - Missing ``scenario`` / ``vision_*`` / ``native_arm_model`` keys all - default to ``head-to-head`` / ``None`` so old setups keep working - after upgrade — the runner's behaviour exactly mirrors the legacy - one (both arms answer with ``provider_model``). - """ - - legacy = { + payload = { "search_space_id": 7, - "agent_llm_id": -123, + "chat_model_id": -123, "provider_model": "anthropic/claude-sonnet-4.5", "created_at": "2026-05-11T20-30-00Z", "ingestion_maps": {}, } - state = SuiteState.from_dict(legacy) + state = SuiteState.from_dict(payload) assert state.scenario == DEFAULT_SCENARIO == "head-to-head" - assert state.vision_llm_config_id is None + assert state.vision_model_id is None assert state.vision_provider_model is None assert state.native_arm_model is None # The native arm should still answer with the same slug as SurfSense. @@ -118,7 +112,7 @@ def test_unknown_scenario_falls_back_to_default(): payload = { "search_space_id": 1, - "agent_llm_id": -1, + "chat_model_id": -1, "provider_model": "openai/gpt-5", "scenario": "unknown-scenario-name", } @@ -130,11 +124,11 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 config = load_config() state = SuiteState( search_space_id=42, - agent_llm_id=-1, + chat_model_id=-1, provider_model="openai/gpt-5.4-mini", created_at="2026-05-11T20-30-00Z", scenario="cost-arbitrage", - vision_llm_config_id=-101, + vision_model_id=-101, vision_provider_model="anthropic/claude-sonnet-4.5", native_arm_model="anthropic/claude-sonnet-4.5", ) @@ -142,7 +136,7 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 fetched = get_suite_state(config, "medical") assert fetched.scenario == "cost-arbitrage" - assert fetched.vision_llm_config_id == -101 + assert fetched.vision_model_id == -101 assert fetched.vision_provider_model == "anthropic/claude-sonnet-4.5" assert fetched.native_arm_model == "anthropic/claude-sonnet-4.5" # Cost arbitrage's whole point: native arm slug != surfsense slug. diff --git a/surfsense_evals/tests/test_integration_smoke.py b/surfsense_evals/tests/test_integration_smoke.py index 493c04c25..1c89ae5ab 100644 --- a/surfsense_evals/tests/test_integration_smoke.py +++ b/surfsense_evals/tests/test_integration_smoke.py @@ -27,7 +27,7 @@ async def test_smoke_against_localhost(): pytest.skip("No credentials in environment; skipping integration smoke") bundle = await acquire_token(config) async with client_with_auth(config, bundle) as client: - response = await client.get(f"{config.surfsense_api_base}/api/v1/global-new-llm-configs") + response = await client.get(f"{config.surfsense_api_base}/api/v1/model-connections/global") try: response.raise_for_status() except httpx.HTTPStatusError as exc: diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx deleted file mode 100644 index b300f8078..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { ImageModelManager } from "@/components/settings/image-model-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx deleted file mode 100644 index 5bad50cd3..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { LLMRoleManager } from "@/components/settings/llm-role-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx deleted file mode 100644 index 06aea003a..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { VisionModelManager } from "@/components/settings/vision-model-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts deleted file mode 100644 index 922c398c9..000000000 --- a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts +++ /dev/null @@ -1,96 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateImageGenConfigRequest, - CreateImageGenConfigResponse, - DeleteImageGenConfigResponse, - GetImageGenConfigsResponse, - UpdateImageGenConfigRequest, - UpdateImageGenConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Mutation atom for creating a new ImageGenerationConfig - */ -export const createImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateImageGenConfigRequest) => { - return imageGenConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateImageGenConfigResponse, request: CreateImageGenConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create image model"); - }, - }; -}); - -/** - * Mutation atom for updating an existing ImageGenerationConfig - */ -export const updateImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateImageGenConfigRequest) => { - return imageGenConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update image model"); - }, - }; -}); - -/** - * Mutation atom for deleting an ImageGenerationConfig - */ -export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: { id: number; name: string }) => { - return imageGenConfigApiService.deleteConfig(request.id); - }, - onSuccess: (_: DeleteImageGenConfigResponse, request: { id: number; name: string }) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - (oldData: GetImageGenConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete image model"); - }, - }; -}); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts deleted file mode 100644 index a45e69a03..000000000 --- a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Query atom for fetching user-created image gen configs for the active search space - */ -export const imageGenConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return imageGenConfigApiService.getConfigs(Number(searchSpaceId)); - }, - }; -}); - -/** - * Query atom for fetching global image gen configs (from YAML, negative IDs) - */ -export const globalImageGenConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.imageGenConfigs.global(), - staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change - queryFn: async () => { - return imageGenConfigApiService.getGlobalConfigs(); - }, - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts deleted file mode 100644 index 476d89d4c..000000000 --- a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts +++ /dev/null @@ -1,132 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateNewLLMConfigRequest, - CreateNewLLMConfigResponse, - DeleteNewLLMConfigRequest, - DeleteNewLLMConfigResponse, - GetNewLLMConfigsResponse, - UpdateLLMPreferencesRequest, - UpdateNewLLMConfigRequest, - UpdateNewLLMConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Mutation atom for creating a new NewLLMConfig - */ -export const createNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateNewLLMConfigRequest) => { - return newLLMConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateNewLLMConfigResponse, request: CreateNewLLMConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create model"); - }, - }; -}); - -/** - * Mutation atom for updating an existing NewLLMConfig - */ -export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateNewLLMConfigRequest) => { - return newLLMConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateNewLLMConfigResponse, request: UpdateNewLLMConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update"); - }, - }; -}); - -/** - * Mutation atom for deleting a NewLLMConfig - */ -export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: DeleteNewLLMConfigRequest & { name: string }) => { - return newLLMConfigApiService.deleteConfig({ id: request.id }); - }, - onSuccess: ( - _: DeleteNewLLMConfigResponse, - request: DeleteNewLLMConfigRequest & { name: string } - ) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - (oldData: GetNewLLMConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete"); - }, - }; -}); - -/** - * Mutation atom for updating LLM preferences (role assignments) - */ -export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["llm-preferences", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateLLMPreferencesRequest) => { - return newLLMConfigApiService.updateLLMPreferences(request); - }, - onSuccess: (_data, request: UpdateLLMPreferencesRequest) => { - queryClient.setQueryData( - cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), - (old: Record | undefined) => ({ ...old, ...request.data }) - ); - // Automation eligibility is derived from these model preferences - // (agent/image/vision). Invalidate it so the automations gate alert - // reflects the new selection without a manual refresh. - queryClient.invalidateQueries({ - queryKey: cacheKeys.automations.modelEligibility(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update LLM preferences"); - }, - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts deleted file mode 100644 index 410d061e5..000000000 --- a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts +++ /dev/null @@ -1,98 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import type { LLMModel } from "@/contracts/enums/llm-models"; -import { LLM_MODELS } from "@/contracts/enums/llm-models"; -import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; -import { getBearerToken } from "@/lib/auth-utils"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Query atom for fetching all NewLLMConfigs for the active search space - */ -export const newLLMConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return newLLMConfigApiService.getConfigs({ - search_space_id: Number(searchSpaceId), - }); - }, - }; -}); - -/** - * Query atom for fetching global NewLLMConfigs (from YAML, negative IDs) - */ -export const globalNewLLMConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.global(), - staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change - enabled: !!getBearerToken(), - queryFn: async () => { - return newLLMConfigApiService.getGlobalConfigs(); - }, - }; -}); - -/** - * Query atom for fetching LLM preferences (role assignments) for the active search space - */ -export const llmPreferencesAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return newLLMConfigApiService.getLLMPreferences(Number(searchSpaceId)); - }, - }; -}); - -/** - * Query atom for fetching default system instructions template - */ -export const defaultSystemInstructionsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.defaultInstructions(), - staleTime: 60 * 60 * 1000, // 1 hour - this rarely changes - queryFn: async () => { - return newLLMConfigApiService.getDefaultSystemInstructions(); - }, - }; -}); - -/** - * Query atom for the dynamic model catalogue. - * Fetched from the backend (which proxies OpenRouter's public API). - * Falls back to the static hardcoded list on error. - */ -export const modelListAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.modelList(), - staleTime: 60 * 60 * 1000, // 1 hour - models don't change often - placeholderData: LLM_MODELS, - queryFn: async (): Promise => { - const data = await newLLMConfigApiService.getModels(); - const dynamicModels = data.map((m) => ({ - value: m.value, - label: m.label, - provider: m.provider, - contextWindow: m.context_window ?? undefined, - })); - - // Providers covered by the dynamic API (from OpenRouter mapping). - // For uncovered providers (Ollama, Groq, Bedrock, etc.) keep the - // hand-curated static suggestions so users still see model options. - const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); - const staticFallbacks = LLM_MODELS.filter((m) => !coveredProviders.has(m.provider)); - - return [...dynamicModels, ...staticFallbacks]; - }, - }; -}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts deleted file mode 100644 index f46b977d5..000000000 --- a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateVisionLLMConfigRequest, - CreateVisionLLMConfigResponse, - DeleteVisionLLMConfigResponse, - GetVisionLLMConfigsResponse, - UpdateVisionLLMConfigRequest, - UpdateVisionLLMConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const createVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateVisionLLMConfigRequest) => { - return visionLLMConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateVisionLLMConfigResponse, request: CreateVisionLLMConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create vision model"); - }, - }; -}); - -export const updateVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateVisionLLMConfigRequest) => { - return visionLLMConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateVisionLLMConfigResponse, request: UpdateVisionLLMConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update vision model"); - }, - }; -}); - -export const deleteVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: { id: number; name: string }) => { - return visionLLMConfigApiService.deleteConfig(request.id); - }, - onSuccess: (_: DeleteVisionLLMConfigResponse, request: { id: number; name: string }) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - (oldData: GetVisionLLMConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete vision model"); - }, - }; -}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts deleted file mode 100644 index 906ce638f..000000000 --- a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts +++ /dev/null @@ -1,51 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import type { LLMModel } from "@/contracts/enums/llm-models"; -import { VISION_MODELS } from "@/contracts/enums/vision-providers"; -import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const visionLLMConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, - queryFn: async () => { - return visionLLMConfigApiService.getConfigs(Number(searchSpaceId)); - }, - }; -}); - -export const globalVisionLLMConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.visionLLMConfigs.global(), - staleTime: 10 * 60 * 1000, - queryFn: async () => { - return visionLLMConfigApiService.getGlobalConfigs(); - }, - }; -}); - -export const visionModelListAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.visionLLMConfigs.modelList(), - staleTime: 60 * 60 * 1000, - placeholderData: VISION_MODELS, - queryFn: async (): Promise => { - const data = await visionLLMConfigApiService.getModels(); - const dynamicModels = data.map((m) => ({ - value: m.value, - label: m.label, - provider: m.provider, - contextWindow: m.context_window ?? undefined, - })); - - const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); - const staticFallbacks = VISION_MODELS.filter((m) => !coveredProviders.has(m.provider)); - - return [...dynamicModels, ...staticFallbacks]; - }, - }; -}); diff --git a/surfsense_web/components/new-chat/chat-header.tsx b/surfsense_web/components/new-chat/chat-header.tsx index 4716418ee..d65dc93a7 100644 --- a/surfsense_web/components/new-chat/chat-header.tsx +++ b/surfsense_web/components/new-chat/chat-header.tsx @@ -1,17 +1,5 @@ "use client"; -import { useCallback, useState } from "react"; -import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; -import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; -import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; -import type { - GlobalImageGenConfig, - GlobalNewLLMConfig, - GlobalVisionLLMConfig, - ImageGenerationConfig, - NewLLMConfigPublic, - VisionLLMConfig, -} from "@/contracts/types/new-llm-config.types"; import { ModelSelector } from "./model-selector"; interface ChatHeaderProps { @@ -20,148 +8,9 @@ interface ChatHeaderProps { } export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { - // LLM config dialog state - const [dialogOpen, setDialogOpen] = useState(false); - const [selectedConfig, setSelectedConfig] = useState< - NewLLMConfigPublic | GlobalNewLLMConfig | null - >(null); - const [isGlobal, setIsGlobal] = useState(false); - const [dialogMode, setDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Image config dialog state - const [imageDialogOpen, setImageDialogOpen] = useState(false); - const [selectedImageConfig, setSelectedImageConfig] = useState< - ImageGenerationConfig | GlobalImageGenConfig | null - >(null); - const [isImageGlobal, setIsImageGlobal] = useState(false); - const [imageDialogMode, setImageDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Vision config dialog state - const [visionDialogOpen, setVisionDialogOpen] = useState(false); - const [selectedVisionConfig, setSelectedVisionConfig] = useState< - VisionLLMConfig | GlobalVisionLLMConfig | null - >(null); - const [isVisionGlobal, setIsVisionGlobal] = useState(false); - const [visionDialogMode, setVisionDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Default provider for create dialogs - const [defaultLLMProvider, setDefaultLLMProvider] = useState(); - const [defaultImageProvider, setDefaultImageProvider] = useState(); - const [defaultVisionProvider, setDefaultVisionProvider] = useState(); - - // LLM handlers - const handleEditLLMConfig = useCallback( - (config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => { - setSelectedConfig(config); - setIsGlobal(global); - setDialogMode(global ? "view" : "edit"); - setDefaultLLMProvider(undefined); - setDialogOpen(true); - }, - [] - ); - - const handleAddNewLLM = useCallback((provider?: string) => { - setSelectedConfig(null); - setIsGlobal(false); - setDialogMode("create"); - setDefaultLLMProvider(provider); - setDialogOpen(true); - }, []); - - const handleDialogClose = useCallback((open: boolean) => { - setDialogOpen(open); - if (!open) setSelectedConfig(null); - }, []); - - // Image model handlers - const handleAddImageModel = useCallback((provider?: string) => { - setSelectedImageConfig(null); - setIsImageGlobal(false); - setImageDialogMode("create"); - setDefaultImageProvider(provider); - setImageDialogOpen(true); - }, []); - - const handleEditImageConfig = useCallback( - (config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => { - setSelectedImageConfig(config); - setIsImageGlobal(global); - setImageDialogMode(global ? "view" : "edit"); - setDefaultImageProvider(undefined); - setImageDialogOpen(true); - }, - [] - ); - - const handleImageDialogClose = useCallback((open: boolean) => { - setImageDialogOpen(open); - if (!open) setSelectedImageConfig(null); - }, []); - - // Vision model handlers - const handleAddVisionModel = useCallback((provider?: string) => { - setSelectedVisionConfig(null); - setIsVisionGlobal(false); - setVisionDialogMode("create"); - setDefaultVisionProvider(provider); - setVisionDialogOpen(true); - }, []); - - const handleEditVisionConfig = useCallback( - (config: VisionLLMConfig | GlobalVisionLLMConfig, global: boolean) => { - setSelectedVisionConfig(config); - setIsVisionGlobal(global); - setVisionDialogMode(global ? "view" : "edit"); - setDefaultVisionProvider(undefined); - setVisionDialogOpen(true); - }, - [] - ); - - const handleVisionDialogClose = useCallback((open: boolean) => { - setVisionDialogOpen(open); - if (!open) setSelectedVisionConfig(null); - }, []); - return (
- - - - +
); } diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx deleted file mode 100644 index 507a263e0..000000000 --- a/surfsense_web/components/settings/agent-model-manager.tsx +++ /dev/null @@ -1,423 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, FileText, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { NewLLMConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface AgentModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - // Mutations - const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue( - deleteNewLLMConfigMutationAtom - ); - - // Queries - const { - data: configs, - isFetching: isLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(newLLMConfigsAtom); - const { data: globalConfigs = [] } = useAtomValue(globalNewLLMConfigsAtom); - - // Members for user resolution - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - // Permissions - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:create") ?? false)); - const canUpdate = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:update") ?? false)); - const canDelete = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:delete") ?? false)); - const isReadOnly = !canCreate && !canUpdate && !canDelete; - - // Local state - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation state - } - }; - - const openEditDialog = (config: NewLLMConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - return ( -
- {/* Header actions */} -
- - {canCreate && ( - - )} -
- - {/* Fetch Error Alert */} - {fetchError && ( -
- - - - {fetchError?.message ?? "Failed to load configurations"} - - -
- )} - - {/* Read-only / Limited permissions notice */} - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to LLM - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canUpdate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create", canUpdate && "edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {/* Global Configs Info */} - {(isLoading || globalConfigs.length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"} - {" "} - available from your administrator. -

- )} -
-
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
- )} - - {/* Configurations List */} - {!isLoading && ( -
- {configs?.length === 0 ? ( -
- - -

No Models Yet

-

- {canCreate - ? "Add your first model to power chat, reports, and other agent capabilities" - : "No models have been added to this space yet. Contact a space owner to add one"} -

-
-
-
- ) : ( -
- {configs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Feature badges */} -
- {config.citations_enabled && ( - - Citations - - )} - {!config.use_default_system_instructions && - config.system_instructions && ( - - - Custom - - )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - {/* Add/Edit Configuration Dialog */} - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - {/* Delete Confirmation Dialog */} - !open && setConfigToDelete(null)} - > - - - Delete Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? This - action cannot be undone. - - - - Cancel - - {isDeleting ? ( - <> - - Deleting - - ) : ( - "Delete" - )} - - - - -
- ); -} diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx deleted file mode 100644 index 494f7aae9..000000000 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ /dev/null @@ -1,489 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface ImageModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - - const { - mutateAsync: deleteConfig, - isPending: isDeleting, - error: deleteError, - } = useAtomValue(deleteImageGenConfigMutationAtom); - - const { - data: userConfigs, - isFetching: configsLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(imageGenConfigsAtom); - const { data: globalConfigs = [], isFetching: globalLoading } = - useAtomValue(globalImageGenConfigsAtom); - - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = - !!access && - (access.is_owner || (access.permissions?.includes("image_generations:create") ?? false)); - const canDelete = - !!access && - (access.is_owner || (access.permissions?.includes("image_generations:delete") ?? false)); - const canUpdate = canCreate; - const isReadOnly = !canCreate && !canDelete; - - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const isLoading = configsLoading || globalLoading; - const errors = [deleteError, fetchError].filter(Boolean) as Error[]; - - const openEditDialog = (config: ImageGenerationConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation - } - }; - - return ( -
- {/* Header actions */} -
- - {canCreate && ( - - )} -
- - {/* Errors */} - {errors.map((err) => ( -
- - - {err?.message} - -
- ))} - - {/* Read-only / Limited permissions notice */} - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to image generation - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create and edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - image model configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {/* Global info */} - {(isLoading || - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} - global image{" "} - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === - 1 - ? "model" - : "models"} - {" "} - available from your administrator. {(() => { - const nonAuto = globalConfigs.filter( - (g) => !("is_auto_mode" in g && g.is_auto_mode) - ); - const premium = nonAuto.filter( - (g) => - "billing_tier" in g && - (g as { billing_tier?: string }).billing_tier === "premium" - ).length; - const free = nonAuto.length - premium; - if (premium > 0 && free > 0) { - return `${premium} premium, ${free} free.`; - } - if (premium > 0) { - return `All ${premium} premium — debits your shared credit pool.`; - } - return `All ${free} free.`; - })()} -

- )} -
-
- )} - - {/* Global Image Models — read-only cards with per-model Free/Premium - badges. Mirrors the badge palette used by the chat role selector - (`llm-role-manager.tsx`) so the meaning is consistent across - every model-configuration surface (chat / image / vision). */} - {!isLoading && - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( -
-
- {globalConfigs - .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) - .map((cfg) => { - const billingTier = - ("billing_tier" in cfg && - typeof (cfg as { billing_tier?: string }).billing_tier === "string" && - (cfg as { billing_tier?: string }).billing_tier) || - "free"; - const isPremium = billingTier === "premium"; - return ( - - -
-
- {getProviderIcon(cfg.provider, { className: "size-4" })} -
-
-

- {cfg.name} -

- {isPremium ? ( - - Premium - - ) : ( - - Free - - )} -
-
- {cfg.description && ( -

- {cfg.description} -

- )} -
- -
- - {cfg.model_name} - -
-
-
-
- ); - })} -
-
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
-
-
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
-
-
- )} - - {/* User Configs */} - {!isLoading && ( -
- {(userConfigs?.length ?? 0) === 0 ? ( - - -

No Image Models Yet

-

- {canCreate - ? "Add your own image generation model (DALL-E 3, GPT Image 1, etc.)" - : "No image models have been added to this space yet. Contact a space owner to add one."} -

-
-
- ) : ( -
- {userConfigs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - {/* Create/Edit Dialog — shared component */} - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - {/* Delete Confirmation */} - !open && setConfigToDelete(null)} - > - - - Delete Image Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? - - - - Cancel - - Delete - {isDeleting && } - - - - -
- ); -} diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx deleted file mode 100644 index 547675927..000000000 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ /dev/null @@ -1,443 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { - AlertCircle, - Bot, - CircleCheck, - CircleDashed, - FileText, - ImageIcon, - RefreshCw, - ScanEye, -} from "lucide-react"; -import { useCallback, useEffect, useState } from "react"; -import { toast } from "sonner"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectGroup, - SelectItem, - SelectLabel, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { cn } from "@/lib/utils"; - -const ROLE_DESCRIPTIONS = { - agent: { - icon: Bot, - title: "Chat model", - description: "Primary model for chat interactions and agent operations", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "agent_llm_id" as const, - configType: "llm" as const, - }, - image_generation: { - icon: ImageIcon, - title: "Image Generation Model", - description: "Model used for AI image generation (DALL-E, GPT Image, etc.)", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "image_generation_config_id" as const, - configType: "image" as const, - }, - vision: { - icon: ScanEye, - title: "Vision LLM", - description: "Vision-capable model for screenshot analysis and context extraction", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "vision_llm_config_id" as const, - configType: "vision" as const, - }, -}; - -interface LLMRoleManagerProps { - searchSpaceId: number; -} - -export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { - // LLM configs - const { - data: newLLMConfigs = [], - isFetching: configsLoading, - error: configsError, - refetch: refreshConfigs, - } = useAtomValue(newLLMConfigsAtom); - const { - data: globalConfigs = [], - isFetching: globalConfigsLoading, - error: globalConfigsError, - } = useAtomValue(globalNewLLMConfigsAtom); - - // Image gen configs - const { - data: userImageConfigs = [], - isFetching: imageConfigsLoading, - error: imageConfigsError, - } = useAtomValue(imageGenConfigsAtom); - const { - data: globalImageConfigs = [], - isFetching: globalImageConfigsLoading, - error: globalImageConfigsError, - } = useAtomValue(globalImageGenConfigsAtom); - - // Vision LLM configs - const { - data: userVisionConfigs = [], - isFetching: visionConfigsLoading, - error: visionConfigsError, - } = useAtomValue(visionLLMConfigsAtom); - const { - data: globalVisionConfigs = [], - isFetching: globalVisionConfigsLoading, - error: globalVisionConfigsError, - } = useAtomValue(globalVisionLLMConfigsAtom); - - // Preferences - const { - data: preferences = {}, - isFetching: preferencesLoading, - error: preferencesError, - } = useAtomValue(llmPreferencesAtom); - - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - const [assignments, setAssignments] = useState>(() => ({ - agent_llm_id: preferences.agent_llm_id ?? null, - image_generation_config_id: preferences.image_generation_config_id ?? null, - vision_llm_config_id: preferences.vision_llm_config_id ?? null, - })); - - // Sync local state when preferences load/change. Without this, the selects - // stay on their initial (often empty) value while the query is in flight, - // so a saved assignment — including Auto mode (id 0) — never appears. - useEffect(() => { - setAssignments({ - agent_llm_id: preferences.agent_llm_id ?? null, - image_generation_config_id: preferences.image_generation_config_id ?? null, - vision_llm_config_id: preferences.vision_llm_config_id ?? null, - }); - }, [ - preferences.agent_llm_id, - preferences.image_generation_config_id, - preferences.vision_llm_config_id, - ]); - - const [savingRole, setSavingRole] = useState(null); - - const handleRoleAssignment = useCallback( - async (prefKey: string, configId: string) => { - // "unassigned" clears the role (null). Every other option — including - // Auto mode, whose config id is 0 — must be sent as-is. Using a falsy - // check here (e.g. `value || undefined`) would drop id 0 and silently - // fail to persist Auto mode. - const value = configId === "unassigned" ? null : Number(configId); - - setAssignments((prev) => ({ ...prev, [prefKey]: value })); - setSavingRole(prefKey); - - try { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { [prefKey]: value }, - }); - toast.success("Role assignment updated"); - } finally { - setSavingRole(null); - } - }, - [updatePreferences, searchSpaceId] - ); - - // Combine global and custom LLM configs - const allLLMConfigs = [ - ...globalConfigs.map((config) => ({ ...config, is_global: true })), - ...newLLMConfigs.filter((config) => config.id && config.id.toString().trim() !== ""), - ]; - - // Combine global and custom image gen configs - const allImageConfigs = [ - ...globalImageConfigs.map((config) => ({ ...config, is_global: true })), - ...(userImageConfigs ?? []).filter((config) => config.id && config.id.toString().trim() !== ""), - ]; - - // Combine global and custom vision LLM configs - const allVisionConfigs = [ - ...globalVisionConfigs.map((config) => ({ ...config, is_global: true })), - ...(userVisionConfigs ?? []).filter( - (config) => config.id && config.id.toString().trim() !== "" - ), - ]; - - const isLoading = - configsLoading || - preferencesLoading || - globalConfigsLoading || - imageConfigsLoading || - globalImageConfigsLoading || - visionConfigsLoading || - globalVisionConfigsLoading; - const hasError = - configsError || - preferencesError || - globalConfigsError || - imageConfigsError || - globalImageConfigsError || - visionConfigsError || - globalVisionConfigsError; - const hasAnyConfigs = allLLMConfigs.length > 0 || allImageConfigs.length > 0; - - return ( -
- {/* Header actions */} -
- -
- - {/* Error Alert */} - {hasError && ( -
- - - - {(configsError?.message ?? "Failed to load LLM configurations") || - (preferencesError?.message ?? "Failed to load preferences") || - (globalConfigsError?.message ?? "Failed to load global configurations")} - - -
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
- )} - - {/* No configs warning */} - {!isLoading && !hasError && !hasAnyConfigs && ( - - - - No configurations found. Please add at least one LLM provider or image model in the - respective settings tabs before assigning roles. - - - )} - - {/* Role Assignment Cards */} - {!isLoading && !hasError && hasAnyConfigs && ( -
- {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { - const IconComponent = role.icon; - const currentAssignment = assignments[role.prefKey as keyof typeof assignments]; - - // Pick the right config lists based on role type - const roleGlobalConfigs = - role.configType === "image" - ? globalImageConfigs - : role.configType === "vision" - ? globalVisionConfigs - : globalConfigs; - const roleUserConfigs = - role.configType === "image" - ? (userImageConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") - : role.configType === "vision" - ? (userVisionConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") - : newLLMConfigs.filter((c) => c.id && c.id.toString().trim() !== ""); - const roleAllConfigs = - role.configType === "image" - ? allImageConfigs - : role.configType === "vision" - ? allVisionConfigs - : allLLMConfigs; - - const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment); - const isAssigned = !!assignedConfig; - - return ( -
- - - {/* Role Header */} -
-
-
- -
-
-

{role.title}

-

- {role.description} -

-
-
- {savingRole === role.prefKey ? ( - - ) : isAssigned ? ( - - ) : ( - - )} -
- - {/* Selector */} -
- - -
-
-
-
- ); - })} -
- )} -
- ); -} diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx deleted file mode 100644 index 31578b4f1..000000000 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ /dev/null @@ -1,486 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { deleteVisionLLMConfigMutationAtom } from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; -import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { VisionLLMConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface VisionModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - - const { - mutateAsync: deleteConfig, - isPending: isDeleting, - error: deleteError, - } = useAtomValue(deleteVisionLLMConfigMutationAtom); - - const { - data: userConfigs, - isFetching: configsLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(visionLLMConfigsAtom); - const { data: globalConfigs = [], isFetching: globalLoading } = useAtomValue( - globalVisionLLMConfigsAtom - ); - - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = useMemo(() => { - if (!access) return false; - if (access.is_owner) return true; - return access.permissions?.includes("vision_configs:create") ?? false; - }, [access]); - const canDelete = useMemo(() => { - if (!access) return false; - if (access.is_owner) return true; - return access.permissions?.includes("vision_configs:delete") ?? false; - }, [access]); - const canUpdate = canCreate; - const isReadOnly = !canCreate && !canDelete; - - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const isLoading = configsLoading || globalLoading; - const errors = [deleteError, fetchError].filter(Boolean) as Error[]; - - const openEditDialog = (config: VisionLLMConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation - } - }; - - return ( -
-
- - {canCreate && ( - - )} -
- - {errors.map((err) => ( -
- - - {err?.message} - -
- ))} - - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to vision model - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create and edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - vision model configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {(isLoading || - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} - global vision{" "} - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === - 1 - ? "model" - : "models"} - {" "} - available from your administrator. {(() => { - const nonAuto = globalConfigs.filter( - (g) => !("is_auto_mode" in g && g.is_auto_mode) - ); - const premium = nonAuto.filter( - (g) => - "billing_tier" in g && - (g as { billing_tier?: string }).billing_tier === "premium" - ).length; - const free = nonAuto.length - premium; - if (premium > 0 && free > 0) { - return `${premium} premium, ${free} free.`; - } - if (premium > 0) { - return `All ${premium} premium — debits your shared credit pool.`; - } - return `All ${free} free.`; - })()} -

- )} -
-
- )} - - {/* Global Vision Models — read-only cards with per-model Free/Premium - badges. Mirrors the badge palette used by the chat role selector - (`llm-role-manager.tsx`) so the meaning is consistent across - every model-configuration surface (chat / image / vision). */} - {!isLoading && - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( -
-
- {globalConfigs - .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) - .map((cfg) => { - const billingTier = - ("billing_tier" in cfg && - typeof (cfg as { billing_tier?: string }).billing_tier === "string" && - (cfg as { billing_tier?: string }).billing_tier) || - "free"; - const isPremium = billingTier === "premium"; - return ( - - -
-
- {getProviderIcon(cfg.provider, { className: "size-4" })} -
-
-

- {cfg.name} -

- {isPremium ? ( - - Premium - - ) : ( - - Free - - )} -
-
- {cfg.description && ( -

- {cfg.description} -

- )} -
- -
- - {cfg.model_name} - -
-
-
-
- ); - })} -
-
- )} - - {isLoading && ( -
-
-
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
-
-
- )} - - {!isLoading && ( -
- {(userConfigs?.length ?? 0) === 0 ? ( - - -

No Vision Models Yet

-

- {canCreate - ? "Add your own vision-capable model (GPT-4o, Claude, Gemini, etc.)" - : "No vision models have been added to this space yet. Contact a space owner to add one."} -

-
-
- ) : ( -
- {userConfigs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - !open && setConfigToDelete(null)} - > - - - Delete Vision Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? - - - - Cancel - - Delete - {isDeleting && } - - - - -
- ); -} diff --git a/surfsense_web/components/shared/image-config-dialog.tsx b/surfsense_web/components/shared/image-config-dialog.tsx deleted file mode 100644 index 36d16081a..000000000 --- a/surfsense_web/components/shared/image-config-dialog.tsx +++ /dev/null @@ -1,456 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { toast } from "sonner"; -import { - createImageGenConfigMutationAtom, - updateImageGenConfigMutationAtom, -} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; -import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers"; -import type { - GlobalImageGenConfig, - ImageGenerationConfig, - ImageGenProvider, -} from "@/contracts/types/new-llm-config.types"; -import { cn } from "@/lib/utils"; - -interface ImageConfigDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - config: ImageGenerationConfig | GlobalImageGenConfig | null; - isGlobal: boolean; - searchSpaceId: number; - mode: "create" | "edit" | "view"; - defaultProvider?: string; -} - -const INITIAL_FORM = { - name: "", - description: "", - provider: "", - model_name: "", - api_key: "", - api_base: "", - api_version: "", -}; - -export function ImageConfigDialog({ - open, - onOpenChange, - config, - isGlobal, - searchSpaceId, - mode, - defaultProvider, -}: ImageConfigDialogProps) { - const [isSubmitting, setIsSubmitting] = useState(false); - const [formData, setFormData] = useState(INITIAL_FORM); - const [modelComboboxOpen, setModelComboboxOpen] = useState(false); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const scrollRef = useRef(null); - - useEffect(() => { - if (open) { - if (mode === "edit" && config && !isGlobal) { - setFormData({ - name: config.name || "", - description: config.description || "", - provider: config.provider || "", - model_name: config.model_name || "", - api_key: (config as ImageGenerationConfig).api_key || "", - api_base: config.api_base || "", - api_version: config.api_version || "", - }); - } else if (mode === "create") { - setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); - } - setScrollPos("top"); - } - }, [open, mode, config, isGlobal, defaultProvider]); - - const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); - const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - const handleScroll = useCallback((e: React.UIEvent) => { - const el = e.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - const suggestedModels = useMemo(() => { - if (!formData.provider) return []; - return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider); - }, [formData.provider]); - - const getTitle = () => { - if (mode === "create") return "Add Image Model"; - if (isGlobal) return "View Global Image Model"; - return "Edit Image Model"; - }; - - const getSubtitle = () => { - if (mode === "create") return "Set up a new image generation provider"; - if (isGlobal) return "Read-only global configuration"; - return "Update your image model settings"; - }; - - const handleSubmit = useCallback(async () => { - setIsSubmitting(true); - try { - if (mode === "create") { - const result = await createConfig({ - name: formData.name, - provider: formData.provider as ImageGenProvider, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - description: formData.description || undefined, - search_space_id: searchSpaceId, - }); - if (result?.id) { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { image_generation_config_id: result.id }, - }); - } - onOpenChange(false); - } else if (!isGlobal && config) { - await updateConfig({ - id: config.id, - data: { - name: formData.name, - description: formData.description || undefined, - provider: formData.provider as ImageGenProvider, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - }, - }); - onOpenChange(false); - } - } catch (error) { - console.error("Failed to save image config:", error); - toast.error("Failed to save image model"); - } finally { - setIsSubmitting(false); - } - }, [ - mode, - isGlobal, - config, - formData, - searchSpaceId, - createConfig, - updateConfig, - updatePreferences, - onOpenChange, - ]); - - const handleUseGlobalConfig = useCallback(async () => { - if (!config || !isGlobal) return; - setIsSubmitting(true); - try { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { image_generation_config_id: config.id }, - }); - toast.success(`Now using ${config.name}`); - onOpenChange(false); - } catch (error) { - console.error("Failed to set image model:", error); - toast.error("Failed to set image model"); - } finally { - setIsSubmitting(false); - } - }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); - - const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; - const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); - - return ( - - e.preventDefault()} - > - {getTitle()} - - {/* Header */} -
-
-
-

{getTitle()}

- {isGlobal && mode !== "create" && ( - - Global - - )} -
-

{getSubtitle()}

- {config && mode !== "create" && ( -

{config.model_name}

- )} -
-
- - {/* Scrollable content */} -
- {isGlobal && config && ( - <> - - - - Global configurations are read-only. To customize, create a new model. - - -
-
-
-
- Name -
-

{config.name}

-
- {config.description && ( -
-
- Description -
-

{config.description}

-
- )} -
- -
-
-
- Provider -
-

{config.provider}

-
-
-
- Model -
-

{config.model_name}

-
-
-
- - )} - - {(mode === "create" || (mode === "edit" && !isGlobal)) && ( -
-
- - setFormData((p) => ({ ...p, name: e.target.value }))} - /> -
- -
- - setFormData((p) => ({ ...p, description: e.target.value }))} - /> -
- - - -
- - -
- -
- - {suggestedModels.length > 0 ? ( - - - - - - - setFormData((p) => ({ ...p, model_name: val }))} - /> - - - - Type a custom model name - - - - {suggestedModels.map((m) => ( - { - setFormData((p) => ({ ...p, model_name: m.value })); - setModelComboboxOpen(false); - }} - > - - {m.value} - - {m.label} - - - ))} - - - - - - ) : ( - setFormData((p) => ({ ...p, model_name: e.target.value }))} - /> - )} -
- -
- - setFormData((p) => ({ ...p, api_key: e.target.value }))} - /> -
- -
- - setFormData((p) => ({ ...p, api_base: e.target.value }))} - /> -
- - {formData.provider === "AZURE_OPENAI" && ( -
- - setFormData((p) => ({ ...p, api_version: e.target.value }))} - /> -
- )} -
- )} -
- - {/* Fixed footer */} -
- - {mode === "create" || (mode === "edit" && !isGlobal) ? ( - - ) : isGlobal && config ? ( - - ) : null} -
-
-
- ); -} diff --git a/surfsense_web/components/shared/llm-config-form.tsx b/surfsense_web/components/shared/llm-config-form.tsx deleted file mode 100644 index 06de4129b..000000000 --- a/surfsense_web/components/shared/llm-config-form.tsx +++ /dev/null @@ -1,527 +0,0 @@ -"use client"; - -import { zodResolver } from "@hookform/resolvers/zod"; -import { useAtomValue } from "jotai"; -import { Check, ChevronDown, ChevronsUpDown } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; -import { type Resolver, useForm } from "react-hook-form"; -import { z } from "zod"; -import { - defaultSystemInstructionsAtom, - modelListAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { - Form, - FormControl, - FormDescription, - FormField, - FormItem, - FormLabel, - FormMessage, -} from "@/components/ui/form"; -import { Input } from "@/components/ui/input"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import { Switch } from "@/components/ui/switch"; -import { Textarea } from "@/components/ui/textarea"; -import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers"; -import type { CreateNewLLMConfigRequest } from "@/contracts/types/new-llm-config.types"; -import { cn } from "@/lib/utils"; -import InferenceParamsEditor from "../inference-params-editor"; - -// Form schema with zod -const formSchema = z.object({ - name: z.string().min(1, "Name is required").max(100), - description: z.string().max(500).optional().nullable(), - provider: z.string().min(1, "Provider is required"), - custom_provider: z.string().max(100).optional().nullable(), - model_name: z.string().min(1, "Model name is required").max(100), - api_key: z.string().min(1, "API key is required"), - api_base: z.string().max(500).optional().nullable(), - litellm_params: z.record(z.string(), z.any()).optional().nullable(), - system_instructions: z.string().default(""), - use_default_system_instructions: z.boolean().default(true), - citations_enabled: z.boolean().default(true), - search_space_id: z.number(), -}); - -type FormValues = z.infer; - -export type LLMConfigFormData = CreateNewLLMConfigRequest; - -interface LLMConfigFormProps { - initialData?: Partial; - searchSpaceId: number; - onSubmit: (data: LLMConfigFormData) => Promise; - mode?: "create" | "edit"; - showAdvanced?: boolean; - formId?: string; -} - -export function LLMConfigForm({ - initialData, - searchSpaceId, - onSubmit, - mode = "create", - showAdvanced = true, - formId, -}: LLMConfigFormProps) { - const { data: defaultInstructions, isSuccess: defaultInstructionsLoaded } = useAtomValue( - defaultSystemInstructionsAtom - ); - const { data: dynamicModels } = useAtomValue(modelListAtom); - const [modelComboboxOpen, setModelComboboxOpen] = useState(false); - const [advancedOpen, setAdvancedOpen] = useState(false); - const [systemInstructionsOpen, setSystemInstructionsOpen] = useState(false); - - const form = useForm({ - resolver: zodResolver(formSchema) as Resolver, - defaultValues: { - name: initialData?.name ?? "", - description: initialData?.description ?? "", - provider: initialData?.provider ?? "", - custom_provider: initialData?.custom_provider ?? "", - model_name: initialData?.model_name ?? "", - api_key: initialData?.api_key ?? "", - api_base: initialData?.api_base ?? "", - litellm_params: initialData?.litellm_params ?? {}, - system_instructions: initialData?.system_instructions ?? "", - use_default_system_instructions: initialData?.use_default_system_instructions ?? true, - citations_enabled: initialData?.citations_enabled ?? true, - search_space_id: searchSpaceId, - }, - }); - - // Load default instructions when available (only for new configs) - useEffect(() => { - if ( - mode === "create" && - defaultInstructionsLoaded && - defaultInstructions?.default_system_instructions && - !form.getValues("system_instructions") - ) { - form.setValue("system_instructions", defaultInstructions.default_system_instructions); - } - }, [defaultInstructionsLoaded, defaultInstructions, mode, form]); - - const watchProvider = form.watch("provider"); - const selectedProvider = LLM_PROVIDERS.find((p) => p.value === watchProvider); - const availableModels = useMemo( - () => (dynamicModels ?? []).filter((m) => m.provider === watchProvider), - [dynamicModels, watchProvider] - ); - - const handleProviderChange = (value: string) => { - form.setValue("provider", value); - form.setValue("model_name", ""); - - // Auto-fill API base for certain providers - const provider = LLM_PROVIDERS.find((p) => p.value === value); - if (provider?.apiBase) { - form.setValue("api_base", provider.apiBase); - } - }; - - const handleFormSubmit = async (values: FormValues) => { - await onSubmit(values as LLMConfigFormData); - }; - - return ( -
- - {/* Model Configuration Section */} -
-
- Model Configuration -
- - {/* Name & Description */} -
- ( - - Configuration Name - - - - - - )} - /> - - ( - - - Description - - Optional - - - - - - - - )} - /> -
- - {/* Provider Selection */} - ( - - LLM Provider - - - - )} - /> - - {/* Custom Provider (conditional) */} - {watchProvider === "CUSTOM" && ( - ( - - Custom Provider Name - - - - - - )} - /> - )} - - {/* Model Name with Combobox */} - ( - - Model Name - - - - - - - - - - - -
- {field.value ? `Using: "${field.value}"` : "Type your model name"} -
-
- {availableModels.length > 0 && ( - - {availableModels - .filter( - (model) => - !field.value || - model.value.toLowerCase().includes(field.value.toLowerCase()) || - model.label.toLowerCase().includes(field.value.toLowerCase()) - ) - .slice(0, 50) - .map((model) => ( - { - field.onChange(value); - setModelComboboxOpen(false); - }} - className="py-2" - > - -
-
{model.label}
- {model.contextWindow && ( -
- Context: {model.contextWindow} -
- )} -
-
- ))} -
- )} -
-
-
-
- {selectedProvider?.example && ( - - Example: {selectedProvider.example} - - )} - -
- )} - /> - - {/* API Credentials */} -
- ( - - API Key - - - - {watchProvider === "OLLAMA" && ( - - Ollama doesn't require auth — enter any value - - )} - - - )} - /> - - ( - - - API Base URL - {selectedProvider?.apiBase && ( - - Auto-filled - - )} - - - - - - - )} - /> -
- - {/* Ollama Quick Actions */} - {watchProvider === "OLLAMA" && ( -
- - -
- )} -
- - {/* Advanced Parameters */} - {showAdvanced && ( - <> - - - - - - - ( - - - - - - - )} - /> - - - - )} - - {/* System Instructions & Citations Section */} - - - - - - - {/* System Instructions */} - ( - -
- Instructions for the AI - {defaultInstructions && ( - - )} -
- -