Merge remote-tracking branch 'upstream/dev' into features/documents-injestion-layered-cached

This commit is contained in:
CREDO23 2026-06-14 11:30:33 +02:00
commit 32a6e54ce6
215 changed files with 9532 additions and 15405 deletions

View file

@ -4,7 +4,7 @@ Revision ID: 138
Revises: 137
Create Date: 2026-04-30
Add a single thread-level column to persist the Auto (Fastest) model pin:
Add a single thread-level column to persist the Auto model pin:
- pinned_llm_config_id: concrete resolved global LLM config id used for this
thread. NULL means "no pin; Auto will resolve on next turn".

View file

@ -15,6 +15,19 @@ down_revision: str | None = "157"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
PUBLICATION_NAME = "zero_publication"
TARGET_STATUS_LABELS = (
"pending",
"awaiting_brief",
"drafting",
"awaiting_review",
"rendering",
"ready",
"failed",
"cancelled",
)
LEGACY_STATUS_LABELS = ("pending", "generating", "ready", "failed")
def _drop_podcasts_from_publication() -> None:
"""Detach podcasts from zero_publication so status can be retyped.
@ -28,31 +41,103 @@ def _drop_podcasts_from_publication() -> None:
published = conn.execute(
sa.text(
"SELECT 1 FROM pg_publication_tables "
"WHERE pubname = 'zero_publication' "
"WHERE pubname = :publication "
"AND schemaname = current_schema() AND tablename = 'podcasts'"
)
),
{"publication": PUBLICATION_NAME},
).fetchone()
if published:
op.execute('ALTER PUBLICATION "zero_publication" DROP TABLE "podcasts";')
op.execute(f'ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts";')
def upgrade() -> None:
_drop_podcasts_from_publication()
def _enum_labels(type_name: str) -> list[str] | None:
rows = (
op.get_bind()
.execute(
sa.text(
"SELECT e.enumlabel "
"FROM pg_type t "
"JOIN pg_namespace n ON n.oid = t.typnamespace "
"JOIN pg_enum e ON e.enumtypid = t.oid "
"WHERE n.nspname = current_schema() AND t.typname = :type_name "
"ORDER BY e.enumsortorder"
),
{"type_name": type_name},
)
.fetchall()
)
if not rows:
return None
return [str(row[0]) for row in rows]
# Retype the status enum by swapping in a fresh type and casting existing
# rows. The legacy transient value 'generating' maps onto 'rendering'.
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;")
def _column_type_name(table: str, column: str) -> str | None:
row = (
op.get_bind()
.execute(
sa.text(
"SELECT udt_name "
"FROM information_schema.columns "
"WHERE table_schema = current_schema() "
"AND table_name = :table AND column_name = :column"
),
{"table": table, "column": column},
)
.fetchone()
)
return str(row[0]) if row else None
def _ensure_status_enum(
*,
desired_labels: tuple[str, ...],
temporary_type: str,
create_sql: str,
alter_sql: str,
default_value: str,
) -> None:
current_labels = _enum_labels("podcast_status")
desired = list(desired_labels)
if current_labels != desired:
if current_labels is None:
if _enum_labels(temporary_type) is None:
raise RuntimeError("podcast_status enum is missing")
elif _enum_labels(temporary_type) is None:
op.execute(f"ALTER TYPE podcast_status RENAME TO {temporary_type};")
else:
raise RuntimeError(
"podcast_status and its temporary replacement both exist"
)
if _enum_labels("podcast_status") is None:
op.execute(create_sql)
if _enum_labels("podcast_status") != desired:
raise RuntimeError("podcast_status enum is not in the expected shape")
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
if _column_type_name("podcasts", "status") != "podcast_status":
op.execute(alter_sql)
op.execute(
"""
f"ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT '{default_value}';"
)
if _enum_labels(temporary_type) is not None:
op.execute(f"DROP TYPE {temporary_type};")
def _upgrade_status_enum() -> None:
_ensure_status_enum(
desired_labels=TARGET_STATUS_LABELS,
temporary_type="podcast_status_old",
create_sql="""
CREATE TYPE podcast_status AS ENUM (
'pending', 'awaiting_brief', 'drafting', 'awaiting_review',
'rendering', 'ready', 'failed', 'cancelled'
);
"""
)
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
op.execute(
"""
""",
alter_sql="""
ALTER TABLE podcasts
ALTER COLUMN status TYPE podcast_status
USING (
@ -61,10 +146,43 @@ def upgrade() -> None:
ELSE status::text
END
)::podcast_status;
"""
""",
default_value="pending",
)
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'pending';")
op.execute("DROP TYPE podcast_status_old;")
def _downgrade_status_enum() -> None:
_ensure_status_enum(
desired_labels=LEGACY_STATUS_LABELS,
temporary_type="podcast_status_new",
create_sql=(
"CREATE TYPE podcast_status AS ENUM "
"('pending', 'generating', 'ready', 'failed');"
),
alter_sql="""
ALTER TABLE podcasts
ALTER COLUMN status TYPE podcast_status
USING (
CASE status::text
WHEN 'awaiting_brief' THEN 'pending'
WHEN 'drafting' THEN 'generating'
WHEN 'awaiting_review' THEN 'generating'
WHEN 'rendering' THEN 'generating'
WHEN 'cancelled' THEN 'failed'
ELSE status::text
END
)::podcast_status;
""",
default_value="ready",
)
def upgrade() -> None:
_drop_podcasts_from_publication()
# Retype the status enum by swapping in a fresh type and casting existing
# rows. The legacy transient value 'generating' maps onto 'rendering'.
_upgrade_status_enum()
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;")
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;")
@ -83,6 +201,8 @@ def upgrade() -> None:
def downgrade() -> None:
_drop_podcasts_from_publication()
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;")
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;")
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;")
@ -92,27 +212,4 @@ def downgrade() -> None:
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;")
# Collapse the expanded lifecycle back onto the original four values.
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_new;")
op.execute(
"CREATE TYPE podcast_status AS ENUM "
"('pending', 'generating', 'ready', 'failed');"
)
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
op.execute(
"""
ALTER TABLE podcasts
ALTER COLUMN status TYPE podcast_status
USING (
CASE status::text
WHEN 'awaiting_brief' THEN 'pending'
WHEN 'drafting' THEN 'generating'
WHEN 'awaiting_review' THEN 'generating'
WHEN 'rendering' THEN 'generating'
WHEN 'cancelled' THEN 'failed'
ELSE status::text
END
)::podcast_status;
"""
)
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'ready';")
op.execute("DROP TYPE podcast_status_new;")
_downgrade_status_enum()

View file

@ -0,0 +1,299 @@
"""add model connections
Revision ID: 160
Revises: 159
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
revision: str = "160"
down_revision: str | None = "159"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
connection_scope = postgresql.ENUM(
"GLOBAL",
"SEARCH_SPACE",
"USER",
name="connectionscope",
create_type=False,
)
model_source = postgresql.ENUM(
"DISCOVERED",
"MANUAL",
name="modelsource",
create_type=False,
)
def _table_exists(table_name: str) -> bool:
return table_name in sa.inspect(op.get_bind()).get_table_names()
def _column_exists(table_name: str, column_name: str) -> bool:
if not _table_exists(table_name):
return False
return column_name in {
column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name)
}
def _index_exists(table_name: str, index_name: str) -> bool:
if not _table_exists(table_name):
return False
return index_name in {
index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name)
}
def _create_index_if_missing(
index_name: str,
table_name: str,
columns: list[str],
) -> None:
if not _index_exists(table_name, index_name):
op.create_index(index_name, table_name, columns, unique=False)
def _add_searchspace_column_if_missing(
column_name: str,
*,
server_default: object | None = None,
) -> None:
if not _column_exists("searchspaces", column_name):
op.add_column(
"searchspaces",
sa.Column(
column_name,
sa.Integer(),
nullable=True,
server_default=server_default,
),
)
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
if _column_exists(table_name, column_name):
op.drop_column(table_name, column_name)
def _drop_index_if_exists(table_name: str, index_name: str) -> None:
if _index_exists(table_name, index_name):
op.drop_index(index_name, table_name=table_name)
def upgrade() -> None:
bind = op.get_bind()
connection_scope.create(bind, checkfirst=True)
model_source.create(bind, checkfirst=True)
if _table_exists("connections"):
if _column_exists("connections", "litellm_provider") and not _column_exists(
"connections", "provider"
):
op.alter_column(
"connections",
"litellm_provider",
new_column_name="provider",
existing_type=sa.String(length=100),
existing_nullable=True,
)
op.alter_column(
"connections",
"provider",
existing_type=sa.String(length=100),
nullable=False,
)
elif _column_exists("connections", "native_provider") and not _column_exists(
"connections", "provider"
):
op.alter_column(
"connections",
"native_provider",
new_column_name="provider",
existing_type=sa.String(length=100),
existing_nullable=True,
)
op.alter_column(
"connections",
"provider",
existing_type=sa.String(length=100),
nullable=False,
)
elif not _column_exists("connections", "provider"):
op.add_column(
"connections",
sa.Column("provider", sa.String(length=100), nullable=False),
)
_drop_index_if_exists("connections", "ix_connections_protocol")
_drop_column_if_exists("connections", "protocol")
else:
op.create_table(
"connections",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("provider", sa.String(length=100), nullable=False),
sa.Column("base_url", sa.String(length=500), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column(
"extra",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("scope", connection_scope, nullable=False),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
),
sa.Column("search_space_id", sa.Integer(), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.CheckConstraint(
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
"(scope = 'USER' AND user_id IS NOT NULL)",
name="ck_connections_scope_owner",
),
sa.ForeignKeyConstraint(
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
if _index_exists(
"connections", "ix_connections_native_provider"
) and not _index_exists("connections", "ix_connections_provider"):
op.execute(
"ALTER INDEX ix_connections_native_provider "
"RENAME TO ix_connections_provider"
)
if _index_exists(
"connections", "ix_connections_litellm_provider"
) and not _index_exists("connections", "ix_connections_provider"):
op.execute(
"ALTER INDEX ix_connections_litellm_provider "
"RENAME TO ix_connections_provider"
)
_create_index_if_missing("ix_connections_provider", "connections", ["provider"])
_create_index_if_missing("ix_connections_scope", "connections", ["scope"])
if not _table_exists("models"):
op.create_table(
"models",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("connection_id", sa.Integer(), nullable=False),
sa.Column("model_id", sa.String(length=255), nullable=False),
sa.Column("display_name", sa.String(length=255), nullable=True),
sa.Column(
"source",
model_source,
server_default="DISCOVERED",
nullable=False,
),
sa.Column("supports_chat", sa.Boolean(), nullable=True),
sa.Column("max_input_tokens", sa.Integer(), nullable=True),
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
sa.Column("supports_tools", sa.Boolean(), nullable=True),
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
sa.Column(
"capabilities_override",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
),
sa.Column("billing_tier", sa.String(length=50), nullable=True),
sa.Column(
"catalog",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.ForeignKeyConstraint(
["connection_id"], ["connections.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"connection_id", "model_id", name="uq_models_connection_model_id"
),
)
else:
if not _column_exists("models", "supports_chat"):
op.add_column(
"models", sa.Column("supports_chat", sa.Boolean(), nullable=True)
)
if not _column_exists("models", "max_input_tokens"):
op.add_column(
"models", sa.Column("max_input_tokens", sa.Integer(), nullable=True)
)
if not _column_exists("models", "supports_image_input"):
op.add_column(
"models", sa.Column("supports_image_input", sa.Boolean(), nullable=True)
)
if not _column_exists("models", "supports_tools"):
op.add_column(
"models", sa.Column("supports_tools", sa.Boolean(), nullable=True)
)
if not _column_exists("models", "supports_image_generation"):
op.add_column(
"models",
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
)
_drop_column_if_exists("models", "capabilities")
_drop_column_if_exists("models", "capabilities_declared")
_drop_column_if_exists("models", "capabilities_verified")
_create_index_if_missing("ix_models_connection_id", "models", ["connection_id"])
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
_add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0"))
_add_searchspace_column_if_missing(
"image_gen_model_id", server_default=sa.text("0")
)
_add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0"))
for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"):
op.alter_column(
"searchspaces",
column_name,
existing_type=sa.Integer(),
existing_nullable=True,
server_default=sa.text("0"),
)
op.execute(
"""
UPDATE searchspaces
SET
chat_model_id = COALESCE(chat_model_id, 0),
image_gen_model_id = COALESCE(image_gen_model_id, 0),
vision_model_id = COALESCE(vision_model_id, 0)
"""
)
op.execute("DROP TYPE IF EXISTS connectionprotocol")
def downgrade() -> None:
op.drop_column("searchspaces", "vision_model_id")
op.drop_column("searchspaces", "image_gen_model_id")
op.drop_column("searchspaces", "chat_model_id")
op.drop_index(op.f("ix_models_billing_tier"), table_name="models")
op.drop_index("ix_models_model_id", table_name="models")
op.drop_index(op.f("ix_models_connection_id"), table_name="models")
op.drop_table("models")
op.drop_index(op.f("ix_connections_scope"), table_name="connections")
op.drop_index(op.f("ix_connections_provider"), table_name="connections")
op.drop_table("connections")
bind = op.get_bind()
model_source.drop(bind, checkfirst=True)
connection_scope.drop(bind, checkfirst=True)

View file

@ -0,0 +1,270 @@
"""remove legacy model config tables
Revision ID: 161
Revises: 160
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.types import TypeEngine
from alembic import op
revision: str = "161"
down_revision: str | None = "160"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
litellm_provider = postgresql.ENUM(
"OPENAI",
"ANTHROPIC",
"GOOGLE",
"AZURE_OPENAI",
"BEDROCK",
"VERTEX_AI",
"GROQ",
"COHERE",
"MISTRAL",
"DEEPSEEK",
"XAI",
"OPENROUTER",
"TOGETHER_AI",
"FIREWORKS_AI",
"REPLICATE",
"PERPLEXITY",
"OLLAMA",
"ALIBABA_QWEN",
"MOONSHOT",
"ZHIPU",
"ANYSCALE",
"DEEPINFRA",
"CEREBRAS",
"SAMBANOVA",
"AI21",
"CLOUDFLARE",
"DATABRICKS",
"COMETAPI",
"HUGGINGFACE",
"GITHUB_MODELS",
"MINIMAX",
"CUSTOM",
name="litellmprovider",
create_type=False,
)
image_gen_provider = postgresql.ENUM(
"OPENAI",
"AZURE_OPENAI",
"GOOGLE",
"VERTEX_AI",
"BEDROCK",
"RECRAFT",
"OPENROUTER",
"XINFERENCE",
"NSCALE",
name="imagegenprovider",
create_type=False,
)
vision_provider = postgresql.ENUM(
"OPENAI",
"ANTHROPIC",
"GOOGLE",
"AZURE_OPENAI",
"VERTEX_AI",
"BEDROCK",
"XAI",
"OPENROUTER",
"OLLAMA",
"GROQ",
"TOGETHER_AI",
"FIREWORKS_AI",
"DEEPSEEK",
"MISTRAL",
"CUSTOM",
name="visionprovider",
create_type=False,
)
def _table_exists(table_name: str) -> bool:
return table_name in sa.inspect(op.get_bind()).get_table_names()
def _column_exists(table_name: str, column_name: str) -> bool:
if not _table_exists(table_name):
return False
return column_name in {
column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name)
}
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
if _column_exists(table_name, column_name):
op.drop_column(table_name, column_name)
def _rename_column_if_exists(
table_name: str,
old_column_name: str,
new_column_name: str,
*,
existing_type: TypeEngine,
existing_nullable: bool = True,
) -> None:
if _column_exists(table_name, old_column_name) and not _column_exists(
table_name, new_column_name
):
op.alter_column(
table_name,
old_column_name,
new_column_name=new_column_name,
existing_type=existing_type,
existing_nullable=existing_nullable,
)
def upgrade() -> None:
for table_name in (
"new_llm_configs",
"vision_llm_configs",
"image_generation_configs",
):
if _table_exists(table_name):
op.drop_table(table_name)
_drop_column_if_exists("searchspaces", "agent_llm_id")
_drop_column_if_exists("searchspaces", "image_generation_config_id")
_drop_column_if_exists("searchspaces", "vision_llm_config_id")
_rename_column_if_exists(
"image_generations",
"image_generation_config_id",
"image_gen_model_id",
existing_type=sa.Integer(),
)
op.execute("DROP TYPE IF EXISTS litellmprovider")
op.execute("DROP TYPE IF EXISTS imagegenprovider")
op.execute("DROP TYPE IF EXISTS visionprovider")
def downgrade() -> None:
bind = op.get_bind()
litellm_provider.create(bind, checkfirst=True)
image_gen_provider.create(bind, checkfirst=True)
vision_provider.create(bind, checkfirst=True)
_rename_column_if_exists(
"image_generations",
"image_gen_model_id",
"image_generation_config_id",
existing_type=sa.Integer(),
)
if _table_exists("searchspaces"):
if not _column_exists("searchspaces", "agent_llm_id"):
op.add_column(
"searchspaces",
sa.Column("agent_llm_id", sa.Integer(), nullable=True),
)
if not _column_exists("searchspaces", "image_generation_config_id"):
op.add_column(
"searchspaces",
sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
)
if not _column_exists("searchspaces", "vision_llm_config_id"):
op.add_column(
"searchspaces",
sa.Column("vision_llm_config_id", sa.Integer(), nullable=True),
)
if not _table_exists("image_generation_configs"):
op.create_table(
"image_generation_configs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("description", sa.String(length=500), nullable=True),
sa.Column("provider", image_gen_provider, nullable=False),
sa.Column("custom_provider", sa.String(length=100), nullable=True),
sa.Column("model_name", sa.String(length=100), nullable=False),
sa.Column("api_key", sa.String(), nullable=False),
sa.Column("api_base", sa.String(length=500), nullable=True),
sa.Column("api_version", sa.String(length=50), nullable=True),
sa.Column("litellm_params", sa.JSON(), nullable=True),
sa.Column("search_space_id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_image_generation_configs_name"),
"image_generation_configs",
["name"],
unique=False,
)
if not _table_exists("vision_llm_configs"):
op.create_table(
"vision_llm_configs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("description", sa.String(length=500), nullable=True),
sa.Column("provider", vision_provider, nullable=False),
sa.Column("custom_provider", sa.String(length=100), nullable=True),
sa.Column("model_name", sa.String(length=100), nullable=False),
sa.Column("api_key", sa.String(), nullable=False),
sa.Column("api_base", sa.String(length=500), nullable=True),
sa.Column("api_version", sa.String(length=50), nullable=True),
sa.Column("litellm_params", sa.JSON(), nullable=True),
sa.Column("search_space_id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_vision_llm_configs_name"),
"vision_llm_configs",
["name"],
unique=False,
)
if not _table_exists("new_llm_configs"):
op.create_table(
"new_llm_configs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("description", sa.String(length=500), nullable=True),
sa.Column("provider", litellm_provider, nullable=False),
sa.Column("custom_provider", sa.String(length=100), nullable=True),
sa.Column("model_name", sa.String(length=100), nullable=False),
sa.Column("api_key", sa.String(), nullable=False),
sa.Column("api_base", sa.String(length=500), nullable=True),
sa.Column("litellm_params", sa.JSON(), nullable=True),
sa.Column("system_instructions", sa.Text(), nullable=False),
sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False),
sa.Column("citations_enabled", sa.Boolean(), nullable=False),
sa.Column("search_space_id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_new_llm_configs_name"),
"new_llm_configs",
["name"],
unique=False,
)

View file

@ -1,15 +1,15 @@
"""add etl_cache_parses table for content-addressed parse reuse
Revision ID: 160
Revises: 159
Revision ID: 162
Revises: 161
"""
from collections.abc import Sequence
from alembic import op
revision: str = "160"
down_revision: str | None = "159"
revision: str = "162"
down_revision: str | None = "161"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

View file

@ -1,15 +1,15 @@
"""add embedding_cache_sets table for content-addressed embedding reuse
Revision ID: 161
Revises: 160
Revision ID: 163
Revises: 162
"""
from collections.abc import Sequence
from alembic import op
revision: str = "161"
down_revision: str | None = "160"
revision: str = "163"
down_revision: str | None = "162"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

View file

@ -3,16 +3,16 @@
Incremental re-indexing keeps unchanged chunk rows, so auto-increment ids no
longer reflect document order. Backfill preserves the historical id ordering.
Revision ID: 162
Revises: 161
Revision ID: 164
Revises: 163
"""
from collections.abc import Sequence
from alembic import op
revision: str = "162"
down_revision: str | None = "161"
revision: str = "164"
down_revision: str | None = "163"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

View file

@ -57,7 +57,7 @@ async def build_agent_with_cache(
mcp_tools_by_agent: dict[str, list[BaseTool]],
disabled_tools: list[str] | None,
config_id: str | None,
image_generation_config_id_override: int | None = None,
image_gen_model_id_override: int | None = None,
) -> Any:
"""Compile the multi-agent graph, serving from cache when key components are stable."""
@ -121,7 +121,7 @@ async def build_agent_with_cache(
# Bound into the generate_image subagent tool at construction time, so it
# must key the compiled-agent cache to avoid leaking one automation's
# image model into another with the same config_id/search_space.
image_generation_config_id_override,
image_gen_model_id_override,
)
return await get_cache().get_or_build(cache_key, builder=_build)

View file

@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent(
mentioned_document_ids: list[int] | None = None,
anon_session_id: str | None = None,
filesystem_selection: FilesystemSelection | None = None,
image_generation_config_id: int | None = None,
image_gen_model_id: int | None = None,
):
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
``image_generation_config_id`` overrides the search space's image model for
``image_gen_model_id`` overrides the search space's image model for
this invocation (used by automations to run on their captured model). When
``None``, the ``generate_image`` tool resolves the live search-space pref.
"""
@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent(
"llm": llm,
# Per-invocation image model override (automations run on their captured
# model). Reaches the generate_image subagent tool via subagent_dependencies.
"image_generation_config_id_override": image_generation_config_id,
"image_gen_model_id_override": image_gen_model_id,
}
_t0 = time.perf_counter()
@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent(
mcp_tools_by_agent=mcp_tools_by_agent,
disabled_tools=disabled_tools,
config_id=config_id,
image_generation_config_id_override=image_generation_config_id,
image_gen_model_id_override=image_gen_model_id,
)
_perf_log.info(
"[create_agent] Middleware stack + graph compiled in %.3fs",

View file

@ -10,70 +10,53 @@ from langgraph.types import Command
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
from app.config import config
from app.db import (
ImageGeneration,
ImageGenerationConfig,
Model,
SearchSpace,
shielded_async_session,
)
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
# Provider mapping (same as routes)
_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini",
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock",
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
def _get_global_model(model_id: int) -> dict | None:
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
def _get_global_image_gen_config(config_id: int) -> dict | None:
"""Get a global image gen config by negative ID."""
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
if cfg.get("id") == config_id:
return cfg
return None
def _get_global_connection(connection_id: int) -> dict | None:
return next(
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
None,
)
def create_generate_image_tool(
search_space_id: int,
db_session: AsyncSession,
image_generation_config_id_override: int | None = None,
image_gen_model_id_override: int | None = None,
):
"""Create ``generate_image`` with bound search space; DB work uses a per-call session.
``image_generation_config_id_override``: when set (automations running on a
captured model), use this config id instead of reading the search space's
live ``image_generation_config_id``.
``image_gen_model_id_override``: when set (automations running on a
captured model), use this model id instead of reading the search space's
live ``image_gen_model_id``.
"""
del db_session # tool uses a fresh per-call session instead
@ -118,26 +101,23 @@ def create_generate_image_tool(
# task's session is shared across every tool; without isolation,
# autoflushes from a concurrent writer poison this tool too.
async with shielded_async_session() as session:
if image_generation_config_id_override is not None:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
return _failed(
{"error": "Search space not found"},
error="Search space not found",
)
if image_gen_model_id_override is not None:
# Automation run: use the captured image model, insulated from
# later search-space changes. No search-space read needed.
config_id = (
image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID
)
config_id = image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
else:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
return _failed(
{"error": "Search space not found"},
error="Search space not found",
)
config_id = (
search_space.image_generation_config_id
or IMAGE_GEN_AUTO_MODE_ID
search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
)
# size/quality/style are intentionally omitted: valid values
@ -147,73 +127,82 @@ def create_generate_image_tool(
gen_kwargs["n"] = n
if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized():
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=search_space.user_id,
capability="image_gen",
)
if not candidates:
err = (
"No image generation models configured. "
"No image generation models available. "
"Please add an image model in Settings > Image Models."
)
return _failed({"error": err}, error=err)
response = await ImageGenRouterService.aimage_generation(
prompt=prompt, model="auto", **gen_kwargs
config_id = int(
choose_auto_model_candidate(candidates, search_space_id)["id"]
)
elif config_id < 0:
cfg = _get_global_image_gen_config(config_id)
if not cfg:
err = f"Image generation config {config_id} not found"
if config_id < 0:
global_model = _get_global_model(config_id)
if not global_model or not has_capability(
global_model, "image_gen"
):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
global_connection = _get_global_connection(
global_model["connection_id"]
)
if not global_connection:
err = f"Image generation connection for model {config_id} not found"
return _failed({"error": err}, error=err)
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
model_string, resolved_kwargs = to_litellm(
global_connection,
global_model["model_id"],
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
# Defense-in-depth: an empty ``api_base`` must not fall
# through to LiteLLM's global ``api_base`` (e.g. Azure).
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
gen_kwargs.update(cfg["litellm_params"])
gen_kwargs.update(resolved_kwargs)
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs
)
else:
# Positive ID = user-created ImageGenerationConfig
# Positive ID = Model + Connection
cfg_result = await session.execute(
select(ImageGenerationConfig).filter(
ImageGenerationConfig.id == config_id
)
select(Model)
.options(selectinload(Model.connection))
.filter(Model.id == config_id, Model.enabled.is_(True))
)
db_cfg = cfg_result.scalars().first()
if not db_cfg:
err = f"Image generation config {config_id} not found"
db_model = cfg_result.scalars().first()
if (
not db_model
or not db_model.connection
or not db_model.connection.enabled
):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
conn = db_model.connection
if (
conn.search_space_id is not None
and conn.search_space_id != search_space_id
):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
if (
conn.user_id is not None
and conn.user_id != search_space.user_id
):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
if not has_capability(db_model, "image_gen"):
err = f"Model {config_id} is not image-generation capable"
return _failed({"error": err}, error=err)
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
model_string, resolved_kwargs = to_litellm(
db_model.connection,
db_model.model_id,
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
# Defense-in-depth: an empty ``api_base`` must not fall
# through to LiteLLM's global ``api_base`` (e.g. Azure).
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
gen_kwargs.update(db_cfg.litellm_params)
gen_kwargs.update(resolved_kwargs)
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs
@ -230,7 +219,7 @@ def create_generate_image_tool(
prompt=prompt,
model=getattr(response, "_hidden_params", {}).get("model"),
n=n,
image_generation_config_id=config_id,
image_gen_model_id=config_id,
response_data=response_dict,
search_space_id=search_space_id,
access_token=access_token,

View file

@ -51,8 +51,6 @@ def load_tools(
create_generate_image_tool(
search_space_id=d["search_space_id"],
db_session=d["db_session"],
image_generation_config_id_override=d.get(
"image_generation_config_id_override"
),
image_gen_model_id_override=d.get("image_gen_model_id_override"),
),
]

View file

@ -2,9 +2,9 @@
LLM configuration utilities for SurfSense agents.
This module provides functions for loading LLM configurations from:
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model
2. YAML files (global configs with negative IDs)
3. Database NewLLMConfig table (user-created configs with positive IDs)
3. Database model-connections table (user-created configs with positive IDs)
It also provides utilities for creating ChatLiteLLM instances and
managing prompt configurations.
@ -24,8 +24,6 @@ from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_litellm import ChatLiteLLM
from litellm import get_model_info
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.runtime.prompt_caching import (
apply_litellm_prompt_caching,
@ -33,10 +31,7 @@ from app.agents.chat.runtime.prompt_caching import (
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
_sanitize_content,
get_auto_mode_llm,
is_auto_mode,
)
@ -51,16 +46,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
reject the blank text. The OpenAI spec says ``content`` should be
``null`` when an assistant message only carries tool calls.
"""
sanitized: list[BaseMessage] = []
for msg in messages:
if isinstance(msg.content, list):
msg.content = _sanitize_content(msg.content)
next_msg = msg.model_copy(deep=True)
if isinstance(next_msg.content, list):
next_msg.content = _sanitize_content(next_msg.content)
if (
isinstance(msg, AIMessage)
and (not msg.content or msg.content == "")
and getattr(msg, "tool_calls", None)
isinstance(next_msg, AIMessage)
and (not next_msg.content or next_msg.content == "")
and getattr(next_msg, "tool_calls", None)
):
msg.content = None # type: ignore[assignment]
return messages
next_msg.content = None # type: ignore[assignment]
sanitized.append(next_msg)
return sanitized
class SanitizedChatLiteLLM(ChatLiteLLM):
@ -91,13 +89,21 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
):
yield chunk
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
# in provider_capabilities so the YAML loader can resolve prefixes during
# app.config init without importing the agent/tools tree.
from app.services.provider_capabilities import ( # noqa: E402
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
)
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
stream: bool | None = None,
**kwargs: Any,
) -> ChatResult:
return await super()._agenerate(
_sanitize_messages(messages),
stop=stop,
run_manager=run_manager,
stream=stream,
**kwargs,
)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
@ -121,8 +127,9 @@ class AgentConfig:
"""
Complete configuration for the SurfSense agent.
This combines LLM settings with prompt configuration from NewLLMConfig.
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
This combines resolved model settings with prompt configuration.
Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
a concrete global or BYOK model before constructing ChatLiteLLM.
"""
# LLM Model Settings
@ -170,7 +177,7 @@ class AgentConfig:
use_default_system_instructions=True,
citations_enabled=True,
config_id=AUTO_MODE_ID,
config_name="Auto (Fastest)",
config_name="Auto",
is_auto_mode=True,
billing_tier="free",
is_premium=False,
@ -181,64 +188,21 @@ class AgentConfig:
supports_image_input=True,
)
@classmethod
def from_new_llm_config(cls, config) -> "AgentConfig":
"""Build an AgentConfig from a NewLLMConfig database model."""
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
from app.services.provider_capabilities import derive_supports_image_input
provider_value = (
config.provider.value
if hasattr(config.provider, "value")
else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
return cls(
provider=provider_value,
model_name=config.model_name,
api_key=config.api_key,
api_base=config.api_base,
custom_provider=config.custom_provider,
litellm_params=config.litellm_params,
system_instructions=config.system_instructions,
use_default_system_instructions=config.use_default_system_instructions,
citations_enabled=config.citations_enabled,
config_id=config.id,
config_name=config.name,
is_auto_mode=False,
billing_tier="free",
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
# unknown). The streaming safety net still blocks explicit text-only.
supports_image_input=derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
),
)
@classmethod
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
"""Build an AgentConfig from a YAML configuration dictionary.
Supports the same prompt fields as NewLLMConfig (system_instructions,
use_default_system_instructions, citations_enabled).
Supports prompt fields such as system_instructions,
use_default_system_instructions, and citations_enabled.
"""
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
from app.services.provider_capabilities import derive_supports_image_input
system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper()
provider = yaml_config.get("provider") or yaml_config.get(
"litellm_provider", ""
)
model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {}
@ -324,93 +288,15 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
return load_llm_config_from_yaml(llm_config_id)
async def load_new_llm_config_from_db(
session: AsyncSession,
config_id: int,
) -> "AgentConfig | None":
"""Load a NewLLMConfig from the database by ID."""
from app.db import NewLLMConfig
try:
result = await session.execute(
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
)
config = result.scalars().first()
if not config:
print(f"Error: NewLLMConfig with id {config_id} not found")
return None
return AgentConfig.from_new_llm_config(config)
except Exception as e:
print(f"Error loading NewLLMConfig from database: {e}")
return None
async def load_agent_llm_config_for_search_space(
session: AsyncSession,
search_space_id: int,
) -> "AgentConfig | None":
"""Load the agent LLM config for a search space via its agent_llm_id.
Positive id -> DB; negative -> YAML; None -> first global config (-1).
"""
from app.db import SearchSpace
try:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
print(f"Error: SearchSpace with id {search_space_id} not found")
return None
config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
)
return await load_agent_config(session, config_id, search_space_id)
except Exception as e:
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
return None
async def load_agent_config(
session: AsyncSession,
config_id: int,
search_space_id: int | None = None,
) -> "AgentConfig | None":
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
if is_auto_mode(config_id):
if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized")
return None
return AgentConfig.from_auto_mode()
if config_id < 0:
# In-memory covers static YAML + dynamic OpenRouter configs.
from app.config import config as app_config
for cfg in app_config.GLOBAL_LLM_CONFIGS:
if cfg.get("id") == config_id:
return AgentConfig.from_yaml_config(cfg)
yaml_config = load_llm_config_from_yaml(config_id)
if yaml_config:
return AgentConfig.from_yaml_config(yaml_config)
return None
else:
return await load_new_llm_config_from_db(session, config_id)
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
"""Create a ChatLiteLLM instance from a global LLM config dictionary."""
if llm_config.get("custom_provider"):
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
else:
provider = llm_config.get("provider", "").upper()
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{llm_config['model_name']}"
provider = llm_config.get("provider") or llm_config.get(
"litellm_provider", "openai"
)
model_string = f"{provider}/{llm_config['model_name']}"
litellm_kwargs = {
"model": model_string,
@ -433,29 +319,17 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
def create_chat_litellm_from_agent_config(
agent_config: AgentConfig,
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
"""Create a ChatLiteLLM from an already resolved concrete model config."""
if agent_config.is_auto_mode:
if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
router_llm = get_auto_mode_llm()
if router_llm is not None:
# Universal injection points only: auto-mode fans out across
# providers, so provider-specific kwargs have no known target.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
print(
"Error: Auto mode must be resolved to a concrete model before LLM creation"
)
return None
if agent_config.custom_provider:
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
else:
provider_prefix = PROVIDER_MAP.get(
agent_config.provider, agent_config.provider.lower()
)
model_string = f"{provider_prefix}/{agent_config.model_name}"
model_string = f"{agent_config.provider}/{agent_config.model_name}"
litellm_kwargs = {
"model": model_string,

View file

@ -33,7 +33,6 @@ from app.config import (
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
@ -622,7 +621,6 @@ async def lifespan(app: FastAPI):
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
# worker readiness. ``shield`` so Uvicorn cancelling startup

View file

@ -39,31 +39,31 @@ async def build_dependencies(
*,
session: AsyncSession,
search_space_id: int,
agent_llm_id: int | None = None,
image_generation_config_id: int | None = None,
vision_llm_config_id: int | None = None,
chat_model_id: int | None = None,
image_gen_model_id: int | None = None,
vision_model_id: int | None = None,
) -> AgentDependencies:
"""Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer.
Resolves the agent LLM from the automation's *captured* model snapshot
(``agent_llm_id``) so runs are insulated from later chat/search-space model
Resolves the chat model from the automation's *captured* model snapshot
(``chat_model_id``) so runs are insulated from later chat/search-space model
changes. The model policy is enforced here as a runtime backstop: a captured
model that is no longer billable (e.g. a premium global config was removed)
fails the run clearly instead of silently consuming a free model.
When ``agent_llm_id`` is ``None`` (no captured snapshot defensive fallback),
fall back to the live search space's ``agent_llm_id`` and validate that.
When ``chat_model_id`` is ``None`` (no captured snapshot defensive fallback),
fall back to the live search space's ``chat_model_id`` and validate that.
"""
if agent_llm_id is not None:
if chat_model_id is not None:
try:
assert_models_billable(
agent_llm_id=agent_llm_id,
image_generation_config_id=image_generation_config_id,
vision_llm_config_id=vision_llm_config_id,
chat_model_id=chat_model_id,
image_gen_model_id=image_gen_model_id,
vision_model_id=vision_model_id,
)
except AutomationModelPolicyError as exc:
raise DependencyError(str(exc)) from exc
resolved_agent_llm_id = agent_llm_id or 0
resolved_chat_model_id = chat_model_id or 0
else:
search_space = await session.get(SearchSpace, search_space_id)
if search_space is None:
@ -72,15 +72,15 @@ async def build_dependencies(
assert_automation_models_billable(search_space)
except AutomationModelPolicyError as exc:
raise DependencyError(str(exc)) from exc
resolved_agent_llm_id = search_space.agent_llm_id or 0
resolved_chat_model_id = search_space.chat_model_id or 0
llm, agent_config, err = await load_llm_bundle(
session,
config_id=resolved_agent_llm_id,
config_id=resolved_chat_model_id,
search_space_id=search_space_id,
)
if err is not None or llm is None:
raise DependencyError(err or "failed to load agent LLM config")
raise DependencyError(err or "failed to load chat model config")
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
session, search_space_id=search_space_id

View file

@ -150,9 +150,9 @@ async def run_agent_task(
deps = await build_dependencies(
session=agent_session,
search_space_id=ctx.search_space_id,
agent_llm_id=ctx.agent_llm_id,
image_generation_config_id=ctx.image_generation_config_id,
vision_llm_config_id=ctx.vision_llm_config_id,
chat_model_id=ctx.chat_model_id,
image_gen_model_id=ctx.image_gen_model_id,
vision_model_id=ctx.vision_model_id,
)
agent = await create_multi_agent_chat_deep_agent(
@ -167,7 +167,7 @@ async def run_agent_task(
firecrawl_api_key=deps.firecrawl_api_key,
thread_visibility=ChatVisibility.PRIVATE,
mentioned_document_ids=mentioned_document_ids,
image_generation_config_id=ctx.image_generation_config_id,
image_gen_model_id=ctx.image_gen_model_id,
)
agent_query, runtime_context = await _resolve_mention_context(

View file

@ -23,9 +23,9 @@ class ActionContext:
# Captured model snapshot from the automation definition (``definition.models``),
# resolved per run instead of the live search space. ``None`` falls back to the
# search space's current prefs (defensive; should not happen post-capture).
agent_llm_id: int | None = None
image_generation_config_id: int | None = None
vision_llm_config_id: int | None = None
chat_model_id: int | None = None
image_gen_model_id: int | None = None
vision_model_id: int | None = None
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]

View file

@ -132,9 +132,7 @@ def _build_action_ctx(
step_id=step.step_id,
search_space_id=automation.search_space_id,
creator_user_id=automation.created_by_user_id,
agent_llm_id=models.agent_llm_id if models else None,
image_generation_config_id=(
models.image_generation_config_id if models else None
),
vision_llm_config_id=models.vision_llm_config_id if models else None,
chat_model_id=models.chat_model_id if models else None,
image_gen_model_id=models.image_gen_model_id if models else None,
vision_model_id=models.vision_model_id if models else None,
)

View file

@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec
class AutomationModels(BaseModel):
"""Captured model profile for an automation.
Snapshotted from the search space's preferences at create time so runs are
insulated from later chat/search-space model changes. Config-id conventions
Snapshotted from the search space's model roles at create time so runs are
insulated from later chat/search-space model changes. Model-id conventions
match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK).
"""
model_config = ConfigDict(extra="forbid")
agent_llm_id: int = 0
image_generation_config_id: int = 0
vision_llm_config_id: int = 0
chat_model_id: int = 0
image_gen_model_id: int = 0
vision_model_id: int = 0
class AutomationDefinition(BaseModel):

View file

@ -57,9 +57,9 @@ class AutomationService:
else:
search_space = await self._assert_models_billable(payload.search_space_id)
payload.definition.models = AutomationModels(
agent_llm_id=search_space.agent_llm_id or 0,
image_generation_config_id=search_space.image_generation_config_id or 0,
vision_llm_config_id=search_space.vision_llm_config_id or 0,
chat_model_id=search_space.chat_model_id or 0,
image_gen_model_id=search_space.image_gen_model_id or 0,
vision_model_id=search_space.vision_model_id or 0,
)
automation = Automation(
@ -225,9 +225,9 @@ class AutomationService:
"""
try:
assert_models_billable(
agent_llm_id=models.agent_llm_id,
image_generation_config_id=models.image_generation_config_id,
vision_llm_config_id=models.vision_llm_config_id,
chat_model_id=models.chat_model_id,
image_gen_model_id=models.image_gen_model_id,
vision_model_id=models.vision_model_id,
)
except AutomationModelPolicyError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc

View file

@ -2,11 +2,11 @@
Automations run unattended, so every run must be **billable**: it may only use
either a premium global model (``billing_tier == "premium"``) or a user-provided
BYOK model (a positive config id pointing at a per-user/per-space DB row). Free
BYOK model (a positive model id pointing at a per-user/per-space DB row). Free
global models and Auto mode are blocked, because Auto can dispatch to a free
deployment and free models aren't metered in premium credits.
Config id conventions (shared across chat / image / vision):
Model id conventions (shared across chat / image / vision):
- ``id == 0`` Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` /
``VISION_AUTO_MODE_ID``). Blocked.
- ``id < 0`` global YAML/OpenRouter config. Allowed only if premium.
@ -24,70 +24,45 @@ from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from app.db import SearchSpace
ModelKind = Literal["llm", "image", "vision"]
ModelKind = Literal["chat", "image", "vision"]
_KIND_LABEL: dict[ModelKind, str] = {
"llm": "agent LLM",
"chat": "chat model",
"image": "image generation model",
"vision": "vision model",
}
def _is_premium_global(kind: ModelKind, config_id: int) -> bool:
"""Return True if a negative (global) config id is a premium tier model."""
def _is_premium_global(model_id: int) -> bool:
"""Return True if a negative (global) model id is a premium tier model."""
from app.config import config as app_config
cfg: dict | None = None
if kind == "llm":
from app.agents.chat.runtime.llm_config import (
load_global_llm_config_by_id,
)
cfg = load_global_llm_config_by_id(config_id)
elif kind == "image":
cfg = next(
(
c
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
if c.get("id") == config_id
),
None,
)
else: # vision
cfg = next(
(
c
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
if c.get("id") == config_id
),
None,
)
if not cfg:
model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None)
if not model:
return False
return str(cfg.get("billing_tier", "free")).lower() == "premium"
return str(model.get("billing_tier", "free")).lower() == "premium"
def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]:
"""Classify a resolved config id as allowed or blocked.
def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]:
"""Classify a resolved model id as allowed or blocked.
Returns ``(allowed, reason)``; ``reason`` is empty when allowed.
"""
label = _KIND_LABEL[kind]
if config_id is None or config_id == 0:
if model_id is None or model_id == 0:
return (
False,
f"The {label} is set to Auto mode. Automations require an explicit "
"premium model or your own (BYOK) model so every run is billable.",
)
if config_id > 0:
# Positive id → user-owned BYOK config. Always allowed.
if model_id > 0:
# Positive id -> user/search-space BYOK model. Always allowed.
return True, ""
# Negative id → global config. Allowed only if premium.
if _is_premium_global(kind, config_id):
# Negative id -> global model. Allowed only if premium.
if _is_premium_global(model_id):
return True, ""
return (
@ -99,27 +74,27 @@ def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]:
def get_model_eligibility(
*,
agent_llm_id: int | None,
image_generation_config_id: int | None,
vision_llm_config_id: int | None,
chat_model_id: int | None,
image_gen_model_id: int | None,
vision_model_id: int | None,
) -> dict:
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids.
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids.
The ID-based core shared by both the search-space path (creation/eligibility)
and the captured-snapshot path (runtime backstop). Each violation is
``{"kind", "config_id", "reason"}``.
``{"kind", "model_id", "reason"}``.
"""
checks: list[tuple[ModelKind, int | None]] = [
("llm", agent_llm_id),
("image", image_generation_config_id),
("vision", vision_llm_config_id),
("chat", chat_model_id),
("image", image_gen_model_id),
("vision", vision_model_id),
]
violations: list[dict] = []
for kind, config_id in checks:
allowed, reason = _classify(kind, config_id)
for kind, model_id in checks:
allowed, reason = _classify(kind, model_id)
if not allowed:
violations.append({"kind": kind, "config_id": config_id, "reason": reason})
violations.append({"kind": kind, "model_id": model_id, "reason": reason})
return {"allowed": not violations, "violations": violations}
@ -131,9 +106,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict:
wrapper over :func:`get_model_eligibility`.
"""
return get_model_eligibility(
agent_llm_id=search_space.agent_llm_id,
image_generation_config_id=search_space.image_generation_config_id,
vision_llm_config_id=search_space.vision_llm_config_id,
chat_model_id=search_space.chat_model_id,
image_gen_model_id=search_space.image_gen_model_id,
vision_model_id=search_space.vision_model_id,
)
@ -150,9 +125,9 @@ class AutomationModelPolicyError(Exception):
def assert_models_billable(
*,
agent_llm_id: int | None,
image_generation_config_id: int | None,
vision_llm_config_id: int | None,
chat_model_id: int | None,
image_gen_model_id: int | None,
vision_model_id: int | None,
) -> None:
"""Raise :class:`AutomationModelPolicyError` if any explicit id is not billable.
@ -160,9 +135,9 @@ def assert_models_billable(
captured model snapshot.
"""
result = get_model_eligibility(
agent_llm_id=agent_llm_id,
image_generation_config_id=image_generation_config_id,
vision_llm_config_id=vision_llm_config_id,
chat_model_id=chat_model_id,
image_gen_model_id=image_gen_model_id,
vision_model_id=vision_model_id,
)
if not result["allowed"]:
raise AutomationModelPolicyError(result["violations"])

View file

@ -115,14 +115,12 @@ def init_worker(**kwargs):
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
initialize_openrouter_integration()
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
# Celery configuration, sourced from the central Config singleton

View file

@ -78,8 +78,7 @@ def load_global_llm_configs():
# stamps) never leak into the cached YAML structure.
configs = copy.deepcopy(data.get("global_llm_configs", []))
# Lazy import keeps the `app.config` -> `app.services` edge one-way
# and matches the `provider_api_base` pattern used elsewhere.
# Lazy import keeps the `app.config` -> `app.services` edge one-way.
from app.services.provider_capabilities import derive_supports_image_input
seen_slugs: dict[str, int] = {}
@ -104,7 +103,7 @@ def load_global_llm_configs():
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
provider=cfg.get("provider"),
provider=cfg.get("provider") or cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
@ -120,10 +119,10 @@ def load_global_llm_configs():
else:
seen_slugs[slug] = cfg.get("id", 0)
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
# Stamp Auto ranking metadata. YAML configs are always
# Tier A — operator-curated, locked first when premium-eligible.
# The OpenRouter refresh tick later re-stamps health for any cfg
# whose provider == "OPENROUTER" via _enrich_health.
# whose provider == "openrouter" via _enrich_health.
try:
from app.services.quality_score import static_score_yaml
@ -133,7 +132,7 @@ def load_global_llm_configs():
cfg["quality_score_static"] = static_q
cfg["quality_score"] = static_q
cfg["quality_score_health"] = None
# YAML cfgs whose provider is OPENROUTER are also subject
# YAML cfgs whose provider is openrouter are also subject
# to health gating against their own /endpoints data — a
# hand-picked dead OR model is still dead. _enrich_health
# re-stamps health_gated for them on the next refresh tick.
@ -211,42 +210,6 @@ def load_global_image_gen_configs():
return []
def load_global_vision_llm_configs():
data = _global_config_data()
if not data:
return []
try:
configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or [])
for cfg in configs:
if isinstance(cfg, dict):
cfg.setdefault("billing_tier", "free")
return configs
except Exception as e:
print(f"Warning: Failed to load global vision LLM configs: {e}")
return []
def load_vision_llm_router_settings():
default_settings = {
"routing_strategy": "usage-based-routing",
"num_retries": 3,
"allowed_fails": 3,
"cooldown_time": 60,
}
data = _global_config_data()
if not data:
return default_settings
try:
settings = data.get("vision_llm_router_settings", {})
return {**default_settings, **settings}
except Exception as e:
print(f"Warning: Failed to load vision LLM router settings: {e}")
return default_settings
def load_image_gen_router_settings():
"""
Load router settings for image generation Auto mode from YAML file.
@ -363,8 +326,8 @@ def initialize_openrouter_integration():
else:
print("Info: OpenRouter integration enabled but no models fetched")
# Image generation + vision LLM emissions are opt-in (issue L).
# Both reuse the catalogue already cached by ``service.initialize``
# Image generation emissions reuse the catalogue already cached by
# ``service.initialize``
# so we don't make additional network calls here.
if settings.get("image_generation_enabled"):
try:
@ -378,21 +341,26 @@ def initialize_openrouter_integration():
except Exception as e:
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
if settings.get("vision_enabled"):
try:
vision_configs = service.get_vision_llm_configs()
if vision_configs:
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
print(
f"Info: OpenRouter integration added {len(vision_configs)} "
f"vision LLM models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
refresh_global_model_catalog()
except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
def materialize_global_configs():
from app.services.global_model_catalog import materialize_global_model_catalog
return materialize_global_model_catalog(
chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []),
image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []),
)
def refresh_global_model_catalog():
connections, models = materialize_global_configs()
config.GLOBAL_CONNECTIONS = connections
config.GLOBAL_MODELS = models
def initialize_pricing_registration():
"""
Teach LiteLLM the per-token cost of every deployment in
@ -430,7 +398,10 @@ def initialize_llm_router():
router_settings = config.ROUTER_SETTINGS
if not all_configs:
print("Info: No global LLM configs found, Auto mode will not be available")
print(
"Info: No global LLM configs found; global Auto pool is unavailable. "
"Auto can still use enabled BYOK models."
)
return
try:
@ -475,32 +446,6 @@ def initialize_image_gen_router():
print(f"Warning: Failed to initialize Image Generation Router: {e}")
def initialize_vision_llm_router():
vision_configs = load_global_vision_llm_configs()
# Reuse the router settings already parsed at Config construction. The
# *configs* list is intentionally re-read from YAML (it must exclude the
# OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS).
router_settings = config.VISION_LLM_ROUTER_SETTINGS
if not vision_configs:
print(
"Info: No global vision LLM configs found, "
"Vision LLM Auto mode will not be available"
)
return
try:
from app.services.vision_llm_router_service import VisionLLMRouterService
VisionLLMRouterService.initialize(vision_configs, router_settings)
print(
f"Info: Vision LLM Router initialized with {len(vision_configs)} models "
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
)
except Exception as e:
print(f"Warning: Failed to initialize Vision LLM Router: {e}")
class Config:
# Check if ffmpeg is installed
if not is_ffmpeg_installed():
@ -762,7 +707,7 @@ class Config:
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
)
# Per-podcast reservation (in micro-USD). One agent LLM call generating
# Per-podcast reservation (in micro-USD). One chat model call generating
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
# premium-model run. Tune via env.
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
@ -882,11 +827,17 @@ class Config:
# Router settings for Image Generation Auto mode
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
# Global Vision LLM Configurations (optional)
GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs()
# Virtual GLOBAL connection/model catalog. This is server-only metadata
# derived from global_llm_config.yaml; GLOBAL keys are not stored in DB.
from app.services.global_model_catalog import (
materialize_global_model_catalog as _materialize_global_model_catalog,
)
# Router settings for Vision LLM Auto mode
VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings()
GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog(
chat_configs=GLOBAL_LLM_CONFIGS,
image_configs=GLOBAL_IMAGE_GEN_CONFIGS,
)
del _materialize_global_model_catalog
# OpenRouter Integration settings (optional)
OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings()

View file

@ -1,362 +1,236 @@
# Global LLM Configuration
#
# SETUP INSTRUCTIONS:
# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys
# 2. For testing: The system will use this example file automatically if global_llm_config.yaml doesn't exist
# 1. Copy this file to global_llm_config.yaml.
# 2. Replace placeholder credentials, endpoints, deployment names, and pricing
# with values from your own provider accounts.
#
# NOTE: The example API keys below are placeholders and won't work.
# Replace them with your actual API keys to enable global configurations.
# This file is intentionally safe to commit. Do not put real API keys in this
# example file.
#
# These configurations will be available to all users as a convenient option
# Users can choose to use these global configs or add their own
# These YAML entries are materialized at startup as server-owned GLOBAL
# connections and models:
#
# AUTO MODE (Recommended):
# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs
# - This helps avoid rate limits by distributing requests across multiple providers
# - New users are automatically assigned Auto mode by default
# - Configure router_settings below to customize the load balancing behavior
# global_llm_configs -> GLOBAL chat models
# global_image_generation_configs -> GLOBAL image generation models
#
# Structure matches NewLLMConfig:
# - Model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled)
# Do not add global_connections or global_models sections here. They are
# runtime-derived metadata exposed through the model-connections APIs.
#
# Static config shape:
# - Connection fields: provider, api_key, api_base, api_version
# - Model fields: model_name, billing_tier, rpm/tpm, capabilities, litellm_params
# - Public no-login SEO metadata: seo_title, seo_description
# - Prompt defaults: system_instructions, use_default_system_instructions,
# citations_enabled
#
# Provider notes:
# - Use the canonical provider field.
# - For Azure, use the bare deployment name in model_name, for example
# model_name: "gpt-5.1". The resolver prefixes the LiteLLM model string from
# provider: "azure".
#
# GLOBAL ID namespace:
# - ID 0 is reserved for Auto mode.
# - Negative IDs are server-owned GLOBAL models.
# - Positive IDs are user/BYOK database models.
# - Keep static IDs unique across chat and image generation.
# - Suggested static ranges: chat -1..-999, image -2001..-2999.
# - Vision is not a separate config/table. Chat models that accept images use
# supports_image_input: true.
#
# COST-BASED PREMIUM CREDITS:
# Each premium config bills the user's USD-credit balance based on the
# actual provider cost reported by LiteLLM. For models LiteLLM already
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
# or any model LiteLLM doesn't have in its built-in pricing table, declare
# per-token costs inline so they bill correctly:
# Each premium model bills the user's USD-credit balance based on provider cost
# reported by LiteLLM. For custom Azure deployments or any model LiteLLM does
# not know, declare per-token costs inline:
#
# litellm_params:
# base_model: "my-custom-azure-deploy"
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
# input_cost_per_token: 0.000003
# output_cost_per_token: 0.000015
# base_model: "my-custom-deployment"
# # USD per token; 0.00000125 == $1.25 per million input tokens.
# input_cost_per_token: 0.00000125
# output_cost_per_token: 0.00001
#
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
# API — no inline declaration needed. Models without resolvable pricing
# debit $0 from the user's balance and log a WARNING.
# OpenRouter dynamic chat models pull pricing automatically from OpenRouter's
# API. Models without resolvable pricing debit $0 and log a warning.
# Router Settings for Auto Mode
# These settings control how the LiteLLM Router distributes requests across models
# =============================================================================
# Chat Auto Mode Router Settings
# =============================================================================
# These settings control how the LiteLLM Router distributes Auto-mode requests
# across curated router-eligible GLOBAL chat deployments.
router_settings:
# Routing strategy options:
# - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits)
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting
# - "least-busy": Routes to least busy deployment
# - "latency-based-routing": Routes based on response latency
# - "usage-based-routing": Routes to deployment with lowest current usage.
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting.
# - "least-busy": Routes to least busy deployment.
# - "latency-based-routing": Routes based on response latency.
routing_strategy: "usage-based-routing"
# Number of retries before failing
num_retries: 3
# Number of failures allowed before cooling down a deployment
allowed_fails: 3
# Cooldown time in seconds after allowed_fails is exceeded
cooldown_time: 60
# Optional fallback map:
# fallbacks:
# - {"azure/gpt-5.1": ["azure/gpt-5.4-mini"]}
# Fallback models (optional) - when primary fails, try these
# Format: [{"primary_model": ["fallback1", "fallback2"]}]
# fallbacks: []
# =============================================================================
# Static GLOBAL Chat Models
# =============================================================================
global_llm_configs:
# Example: OpenAI GPT-4 Turbo with citations enabled
# Premium Azure chat model with image input support and explicit custom
# pricing. This is the current shape to use for hosted GPT 5.x deployments.
- id: -1
name: "Global GPT-4 Turbo"
description: "OpenAI's GPT-4 Turbo with default prompts and citations"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "gpt-4-turbo"
name: "Azure GPT 5.1"
billing_tier: "premium"
anonymous_enabled: false
seo_enabled: false
seo_slug: "azure-gpt-5-1"
quota_reserve_tokens: 4000
provider: "OPENAI"
model_name: "gpt-4-turbo-preview"
api_key: "sk-your-openai-api-key-here"
api_base: ""
# Rate limits for load balancing (requests/tokens per minute)
rpm: 500 # Requests per minute
tpm: 100000 # Tokens per minute
provider: "azure"
model_name: "gpt-5.1"
supports_image_input: true
supports_tools: true
max_input_tokens: 400000
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
# api_version is optional. Include it if your Azure deployment requires a
# specific API version.
# api_version: "2025-04-01-preview"
rpm: 47500
tpm: 14750000
litellm_params:
temperature: 0.7
max_tokens: 4000
# Prompt Configuration
system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS
max_tokens: 16384
base_model: "gpt-5.1"
input_cost_per_token: 0.00000125
output_cost_per_token: 0.00001
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Anthropic Claude 3 Opus
# Larger premium chat model. If your provider prices long-context traffic
# differently, choose a conservative flat price or document the limitation
# next to the inline pricing.
- id: -2
name: "Global Claude 3 Opus"
description: "Anthropic's most capable model with citations"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "claude-3-opus"
name: "Azure GPT 5.4"
billing_tier: "premium"
anonymous_enabled: false
seo_enabled: false
seo_slug: "azure-gpt-5-4"
quota_reserve_tokens: 4000
provider: "ANTHROPIC"
model_name: "claude-3-opus-20240229"
api_key: "sk-ant-your-anthropic-api-key-here"
api_base: ""
rpm: 1000
tpm: 100000
provider: "azure"
model_name: "gpt-5.4"
supports_image_input: true
supports_tools: true
max_input_tokens: 400000
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
rpm: 150000
tpm: 15000000
litellm_params:
temperature: 0.7
max_tokens: 4000
max_tokens: 16384
base_model: "gpt-5.4"
input_cost_per_token: 0.0000025
output_cost_per_token: 0.000015
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Fast model - GPT-3.5 Turbo (citations disabled for speed)
# Free/no-login hosted model. Free models are visible to users when
# anonymous_enabled/seo_enabled are true but do not debit premium credits.
- id: -3
name: "Global GPT-3.5 Turbo (Fast)"
description: "Fast responses without citations for quick queries"
name: "Azure GPT 5.4 Mini"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "gpt-3.5-turbo-fast"
quota_reserve_tokens: 2000
provider: "OPENAI"
model_name: "gpt-3.5-turbo"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 3500 # GPT-3.5 has higher rate limits
tpm: 200000
litellm_params:
temperature: 0.5
max_tokens: 2000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: false # Disabled for faster responses
# Example: Chinese LLM - DeepSeek with custom instructions
- id: -4
name: "Global DeepSeek Chat (Chinese)"
description: "DeepSeek optimized for Chinese language responses"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "deepseek-chat-chinese"
seo_slug: "gpt-5-4-mini-no-login"
seo_title: "Free GPT 5.4 Mini Chat"
seo_description: "Chat with a hosted GPT 5.4 Mini model without signing in."
quota_reserve_tokens: 4000
provider: "DEEPSEEK"
model_name: "deepseek-chat"
api_key: "your-deepseek-api-key-here"
api_base: "https://api.deepseek.com/v1"
rpm: 60
tpm: 100000
litellm_params:
temperature: 0.7
max_tokens: 4000
# Custom system instructions for Chinese responses
system_instructions: |
<system_instruction>
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
Today's date (UTC): {resolved_today}
IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language.
</system_instruction>
use_default_system_instructions: false
citations_enabled: true
# Example: Azure OpenAI GPT-4o
# IMPORTANT: For Azure deployments, always include 'base_model' in litellm_params
# to enable accurate token counting, cost tracking, and max token limits
- id: -5
name: "Global Azure GPT-4o"
description: "Azure OpenAI GPT-4o deployment"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "azure-gpt-4o"
quota_reserve_tokens: 4000
provider: "AZURE"
# model_name format for Azure: azure/<your-deployment-name>
model_name: "azure/gpt-4o-deployment"
provider: "azure"
model_name: "gpt-5.4-mini"
supports_image_input: false
supports_tools: true
max_input_tokens: 128000
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
api_version: "2024-02-15-preview" # Azure API version
rpm: 1000
tpm: 150000
rpm: 15000
tpm: 15000000
litellm_params:
temperature: 0.7
max_tokens: 4000
# REQUIRED for Azure: Specify the underlying OpenAI model
# This fixes "Could not identify azure model" warnings
# Common base_model values: gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-3.5-turbo
base_model: "gpt-4o"
max_tokens: 16384
base_model: "gpt-5.4-mini"
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Azure OpenAI GPT-4 Turbo
- id: -6
name: "Global Azure GPT-4 Turbo"
description: "Azure OpenAI GPT-4 Turbo deployment"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "azure-gpt-4-turbo"
quota_reserve_tokens: 4000
provider: "AZURE"
model_name: "azure/gpt-4-turbo-deployment"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
api_version: "2024-02-15-preview"
rpm: 500
tpm: 100000
litellm_params:
temperature: 0.7
max_tokens: 4000
base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Groq - Fast inference
- id: -7
name: "Global Groq Llama 3"
description: "Ultra-fast Llama 3 70B via Groq"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "groq-llama-3"
quota_reserve_tokens: 8000
provider: "GROQ"
model_name: "llama3-70b-8192"
api_key: "your-groq-api-key-here"
api_base: ""
rpm: 30 # Groq has lower rate limits on free tier
tpm: 14400
litellm_params:
temperature: 0.7
max_tokens: 8000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: MiniMax M3 - High-performance with 512K context window
- id: -8
name: "Global MiniMax M3"
description: "MiniMax M3 with 512K context window and competitive pricing"
billing_tier: "free"
anonymous_enabled: true
seo_enabled: true
seo_slug: "minimax-m3"
quota_reserve_tokens: 4000
provider: "MINIMAX"
model_name: "MiniMax-M3"
api_key: "your-minimax-api-key-here"
api_base: "https://api.minimax.io/v1"
rpm: 60
tpm: 100000
litellm_params:
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0
max_tokens: 4000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Planner LLM - small, fast model used for internal utility tasks
#
# The PLANNER role handles short, structured internal calls (KB query
# rewriting, date extraction, recency classification, etc.) that don't
# need frontier-tier capability. Pointing the planner at a cheap+fast
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
# typically saves 500ms-1.5s per turn vs. routing those same internal
# calls through the user's chat model.
#
# Activation:
# - Mark EXACTLY ONE global config with ``is_planner: true``.
# - If multiple are marked, the first one wins and a WARNING is logged.
# - If none is marked, every internal call falls back to the user's
# chat LLM (same behavior as before this flag existed).
#
# This config is operator-only — it is NOT exposed in the user-facing
# model selector, never billed against premium quota, and the
# billing_tier / anonymous_enabled fields below are ignored.
# Planner LLM. This is operator-only and is not shown in the user-facing
# model selector. Only one global_llm_configs entry should set is_planner.
- id: -9
name: "Global Planner (GPT-4o mini)"
description: "Internal-only planner LLM for query rewriting and classification"
name: "Azure GPT 5.x Nano Planner"
is_planner: true
billing_tier: "free"
anonymous_enabled: false
seo_enabled: false
quota_reserve_tokens: 1000
provider: "OPENAI"
model_name: "gpt-4o-mini"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 3500
tpm: 200000
provider: "azure"
model_name: "gpt-5.4-nano"
supports_image_input: false
supports_tools: false
router_pool_eligible: false
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
rpm: 20000
tpm: 4000000
litellm_params:
temperature: 0
max_tokens: 1000
base_model: "gpt-5.4-nano"
system_instructions: ""
use_default_system_instructions: true
citations_enabled: false
# =============================================================================
# OpenRouter Integration
# OpenRouter Dynamic Model Integration
# =============================================================================
# When enabled, dynamically fetches ALL available models from the OpenRouter API
# and injects them as global configs. This gives premium users access to any model
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
# while free-tier OpenRouter models show up with a green Free badge and do NOT
# consume premium quota.
# Models are fetched at startup and refreshed periodically in the background.
# All calls go through LiteLLM with the openrouter/ prefix.
# When enabled, SurfSense fetches the OpenRouter catalog at startup and injects
# supported models as GLOBAL chat and optionally image-generation models.
# Tier is derived per model from OpenRouter data:
# - model id ends with ":free" -> billing_tier=free
# - prompt and completion pricing are zero -> billing_tier=free
# - otherwise -> billing_tier=premium
#
# Do not use deprecated openrouter_integration.billing_tier or
# openrouter_integration.anonymous_enabled. Use the tier-specific anonymous
# switches below.
openrouter_integration:
enabled: false
api_key: "sk-or-your-openrouter-api-key"
# Tier is derived PER MODEL from OpenRouter's own API signals:
# - id ends with ":free" -> billing_tier=free
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
# - otherwise -> billing_tier=premium
# No global billing_tier knob is honored; any legacy value emits a startup warning.
# Anonymous access is split by tier so operators can expose only free
# models to no-login users without leaking paid inference.
anonymous_enabled_paid: false
anonymous_enabled_free: false
seo_enabled: false
# quota_reserve_tokens: tokens reserved per call for quota enforcement
quota_reserve_tokens: 4000
# id_offset: base negative ID for dynamically generated configs.
# Model IDs are derived deterministically via BLAKE2b so they survive
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
# Base negative ID namespace for dynamic chat models. IDs are derived
# deterministically so they survive catalog churn. Do not overlap static IDs.
id_offset: -10000
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
# Separate base negative ID namespace for dynamic image-generation models.
image_id_offset: -20000
# How often to refresh the OpenRouter catalog. 0 means startup only.
refresh_interval_hours: 24
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
# for per-deployment accounting when OR premium models participate in the
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
# real account limits live at https://openrouter.ai/settings/limits.
# Paid OpenRouter models may join curated router pools when eligible.
rpm: 200
tpm: 1000000
# Rate limits for FREE OpenRouter models. Informational only: free OR
# models are intentionally kept OUT of the LiteLLM Router pool, because
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
# 50-1000 daily requests across every ":free" model combined) —
# per-deployment router accounting can't represent a shared bucket
# correctly. Free OR models stay fully available in the model selector
# and for user-facing Auto thread pinning.
# Free OpenRouter models are available for user-facing selection/pinning but
# should be treated as a shared-account bucket, not normal router capacity.
free_rpm: 20
free_tpm: 100000
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
# contains hundreds of image- and vision-capable models; turning these on
# injects them into the global Image-Generation / Vision-LLM model
# selectors alongside any static configs. Tier (free/premium) is derived
# per model the same way it is for chat (`:free` suffix or zero pricing).
# When a user picks a premium image/vision model the call debits the
# shared $5 USD-cost-based premium credit pool — so leaving these off
# avoids surprise quota burn on existing deployments. Default: false.
# Image generation is opt-in to avoid injecting a large image catalog during
# upgrades. Vision-capable chat models are represented with
# supports_image_input: true.
image_generation_enabled: false
vision_enabled: false
@ -367,191 +241,80 @@ openrouter_integration:
citations_enabled: true
# =============================================================================
# Image Generation Configuration
# Image Generation Auto Mode Router Settings
# =============================================================================
# These configurations power the image generation feature using litellm.aimage_generation().
# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock,
# Recraft, OpenRouter, Xinference, Nscale
#
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs.
# Router Settings for Image Generation Auto Mode
image_generation_router_settings:
routing_strategy: "usage-based-routing"
num_retries: 3
allowed_fails: 3
cooldown_time: 60
# =============================================================================
# Static GLOBAL Image Generation Models
# =============================================================================
global_image_generation_configs:
# Example: OpenAI DALL-E 3
- id: -1
name: "Global DALL-E 3"
description: "OpenAI's DALL-E 3 for high-quality image generation"
provider: "OPENAI"
model_name: "dall-e-3"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
litellm_params: {}
# Example: OpenAI GPT Image 1
- id: -2
name: "Global GPT Image 1"
description: "OpenAI's GPT Image 1 model"
provider: "OPENAI"
model_name: "gpt-image-1"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 50
litellm_params: {}
# Example: Azure OpenAI DALL-E 3
- id: -3
name: "Global Azure DALL-E 3"
description: "Azure-hosted DALL-E 3 deployment"
provider: "AZURE_OPENAI"
model_name: "azure/dall-e-3-deployment"
- id: -2001
name: "Azure GPT Image 1.5"
billing_tier: "premium"
provider: "azure"
model_name: "gpt-image-1.5"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
api_version: "2024-02-15-preview"
rpm: 50
# api_version: "2025-04-01-preview"
rpm: 60
litellm_params:
base_model: "dall-e-3"
base_model: "gpt-image-1.5"
# Example: OpenRouter Gemini Image Generation
# - id: -4
# name: "Global Gemini Image Gen"
# description: "Google Gemini image generation via OpenRouter"
# provider: "OPENROUTER"
# model_name: "google/gemini-2.5-flash-image"
# api_key: "your-openrouter-api-key-here"
# api_base: ""
# rpm: 30
# litellm_params: {}
- id: -2002
name: "Azure GPT Image 1 Mini"
billing_tier: "free"
provider: "azure"
model_name: "gpt-image-1-mini"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
# api_version: "2025-04-01-preview"
rpm: 120
litellm_params:
base_model: "gpt-image-1-mini"
# =============================================================================
# Vision LLM Configuration
# Field Notes
# =============================================================================
# These configurations power the vision autocomplete feature (screenshot analysis).
# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3).
# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock,
# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom
# Common chat/image fields:
# - provider: Canonical provider adapter name. Example: azure, openai,
# anthropic, openrouter, groq, bedrock.
# - model_name: Provider model or deployment id. For Azure, use the bare
# deployment name. The resolver prefixes LiteLLM model strings from provider.
# - api_base: Provider endpoint/root URL. For OpenAI-compatible providers, the
# resolver adds /v1 when needed.
# - api_version: Optional provider-specific API version, stored on the
# materialized connection extra metadata.
# - litellm_params: Passed to LiteLLM when invoking the model. Also used for
# base_model and inline pricing registration.
#
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs.
# Router Settings for Vision LLM Auto Mode
vision_llm_router_settings:
routing_strategy: "usage-based-routing"
num_retries: 3
allowed_fails: 3
cooldown_time: 60
global_vision_llm_configs:
# Example: OpenAI GPT-4o (recommended for vision)
- id: -1
name: "Global GPT-4o Vision"
description: "OpenAI's GPT-4o with strong vision capabilities"
provider: "OPENAI"
model_name: "gpt-4o"
api_key: "sk-your-openai-api-key-here"
api_base: ""
rpm: 500
tpm: 100000
litellm_params:
temperature: 0.3
max_tokens: 1000
# Example: Google Gemini 2.0 Flash
- id: -2
name: "Global Gemini 2.0 Flash"
description: "Google's fast vision model with large context"
provider: "GOOGLE"
model_name: "gemini-2.0-flash"
api_key: "your-google-ai-api-key-here"
api_base: ""
rpm: 1000
tpm: 200000
litellm_params:
temperature: 0.3
max_tokens: 1000
# Example: Anthropic Claude 3.5 Sonnet
- id: -3
name: "Global Claude 3.5 Sonnet Vision"
description: "Anthropic's Claude 3.5 Sonnet with vision support"
provider: "ANTHROPIC"
model_name: "claude-3-5-sonnet-20241022"
api_key: "sk-ant-your-anthropic-api-key-here"
api_base: ""
rpm: 1000
tpm: 100000
litellm_params:
temperature: 0.3
max_tokens: 1000
# Example: Azure OpenAI GPT-4o
# - id: -4
# name: "Global Azure GPT-4o Vision"
# description: "Azure-hosted GPT-4o for vision analysis"
# provider: "AZURE_OPENAI"
# model_name: "azure/gpt-4o-deployment"
# api_key: "your-azure-api-key-here"
# api_base: "https://your-resource.openai.azure.com"
# api_version: "2024-02-15-preview"
# rpm: 500
# tpm: 100000
# litellm_params:
# temperature: 0.3
# max_tokens: 1000
# base_model: "gpt-4o"
# Notes:
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
# - The 'api_key' field will not be exposed to users via API
# - system_instructions: Custom prompt or empty string to use defaults
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
# - All standard LiteLLM providers are supported
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
# These help the router distribute load evenly and avoid rate limit errors
# Chat model fields:
# - supports_image_input: true when the chat model can consume image inputs.
# - supports_tools: true when the model can use tools/function calling.
# - max_input_tokens: Optional UI/catalog metadata for context size.
# - router_pool_eligible: false keeps a model out of shared router pools while
# still allowing direct selection/pinning.
# - is_planner: true marks the internal-only planner model. Only one config
# should set this flag.
#
# Catalog and access fields:
# - billing_tier: "free" or "premium".
# - anonymous_enabled: Whether the model appears in the public no-login catalog.
# - seo_enabled: Whether a /free/<seo_slug> landing page is generated.
# - seo_slug: Stable URL slug for SEO pages. Keep unique and do not change once
# public.
# - seo_title / seo_description: Optional SEO metadata overrides.
# - quota_reserve_tokens: Tokens reserved before each chat LLM call.
# - rpm / tpm: Optional rate limits for router accounting and load balancing.
#
# IMAGE GENERATION NOTES:
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
# - The router uses litellm.aimage_generation() for async image generation
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.
#
# VISION LLM NOTES:
# - Vision configs use the same ID scheme (negative for global, positive for user DB)
# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.)
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
#
# PLANNER LLM NOTES:
# - is_planner: true marks a config as the internal-only planner LLM (small,
# fast model used for KB query rewriting, date extraction, recency
# classification, etc.). Only one config may carry this flag — if
# multiple do, the first one wins and a startup WARNING is logged.
# - When no config is marked is_planner, every internal utility call falls
# back to the user's chat LLM (the historical behavior).
# - Planner configs are NOT shown in the user-facing model selector and
# are NOT billed against the user's premium quota. Their billing_tier,
# anonymous_enabled, seo_* fields are ignored.
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k
# prompt. Frontier models here defeat the purpose of the flag.
#
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
# - seo_enabled: true/false. Whether a /free/<seo_slug> landing page is generated.
# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public.
# - seo_title: Optional HTML title tag override for the model's /free/<slug> page.
# - seo_description: Optional meta description override for the model's /free/<slug> page.
# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement.
# Independent of litellm_params.max_tokens. Used by the token quota service.
# Image generation notes:
# - Image-generation configs use the same GLOBAL ID namespace as chat models.
# - Only RPM is relevant for most image-generation APIs.
# - The runtime uses litellm.aimage_generation().
# - Image billing currently uses billing_tier and model catalog metadata. Keep
# quota reserve tuning in code/catalog unless the materializer copies a YAML
# key for image quota reservation.

View file

@ -198,79 +198,15 @@ class DocumentStatus:
return None
class LiteLLMProvider(StrEnum):
"""
Enum for LLM providers supported by LiteLLM.
"""
OPENAI = "OPENAI"
ANTHROPIC = "ANTHROPIC"
GOOGLE = "GOOGLE"
AZURE_OPENAI = "AZURE_OPENAI"
BEDROCK = "BEDROCK"
VERTEX_AI = "VERTEX_AI"
GROQ = "GROQ"
COHERE = "COHERE"
MISTRAL = "MISTRAL"
DEEPSEEK = "DEEPSEEK"
XAI = "XAI"
OPENROUTER = "OPENROUTER"
TOGETHER_AI = "TOGETHER_AI"
FIREWORKS_AI = "FIREWORKS_AI"
REPLICATE = "REPLICATE"
PERPLEXITY = "PERPLEXITY"
OLLAMA = "OLLAMA"
ALIBABA_QWEN = "ALIBABA_QWEN"
MOONSHOT = "MOONSHOT"
ZHIPU = "ZHIPU"
ANYSCALE = "ANYSCALE"
DEEPINFRA = "DEEPINFRA"
CEREBRAS = "CEREBRAS"
SAMBANOVA = "SAMBANOVA"
AI21 = "AI21"
CLOUDFLARE = "CLOUDFLARE"
DATABRICKS = "DATABRICKS"
COMETAPI = "COMETAPI"
HUGGINGFACE = "HUGGINGFACE"
GITHUB_MODELS = "GITHUB_MODELS"
MINIMAX = "MINIMAX"
CUSTOM = "CUSTOM"
class ConnectionScope(StrEnum):
GLOBAL = "GLOBAL"
SEARCH_SPACE = "SEARCH_SPACE"
USER = "USER"
class ImageGenProvider(StrEnum):
"""
Enum for image generation providers supported by LiteLLM.
This is a subset of LLM providers only those that support image generation.
See: https://docs.litellm.ai/docs/image_generation#supported-providers
"""
OPENAI = "OPENAI"
AZURE_OPENAI = "AZURE_OPENAI"
GOOGLE = "GOOGLE" # Google AI Studio
VERTEX_AI = "VERTEX_AI"
BEDROCK = "BEDROCK" # AWS Bedrock
RECRAFT = "RECRAFT"
OPENROUTER = "OPENROUTER"
XINFERENCE = "XINFERENCE"
NSCALE = "NSCALE"
class VisionProvider(StrEnum):
OPENAI = "OPENAI"
ANTHROPIC = "ANTHROPIC"
GOOGLE = "GOOGLE"
AZURE_OPENAI = "AZURE_OPENAI"
VERTEX_AI = "VERTEX_AI"
BEDROCK = "BEDROCK"
XAI = "XAI"
OPENROUTER = "OPENROUTER"
OLLAMA = "OLLAMA"
GROQ = "GROQ"
TOGETHER_AI = "TOGETHER_AI"
FIREWORKS_AI = "FIREWORKS_AI"
DEEPSEEK = "DEEPSEEK"
MISTRAL = "MISTRAL"
CUSTOM = "CUSTOM"
class ModelSource(StrEnum):
DISCOVERED = "DISCOVERED"
MANUAL = "MANUAL"
class LogLevel(StrEnum):
@ -699,11 +635,11 @@ class NewChatThread(BaseModel, TimestampMixin):
default=False,
server_default="false",
)
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
# Auto model pin for this thread: concrete resolved global LLM
# config id. NULL means no pin; Auto will resolve on the next turn.
# Single-writer invariant: only app.services.auto_model_pin_service sets
# or clears this column (plus bulk clears when a search space's
# agent_llm_id changes). Unindexed: all reads are by primary key.
# chat_model_id changes). Unindexed: all reads are by primary key.
pinned_llm_config_id = Column(Integer, nullable=True)
# Surface metadata for first-party SurfSense and external chat threads.
@ -1607,73 +1543,80 @@ class Report(BaseModel, TimestampMixin):
thread = relationship("NewChatThread")
class ImageGenerationConfig(BaseModel, TimestampMixin):
"""
Dedicated configuration table for image generation models.
class Connection(BaseModel, TimestampMixin):
__tablename__ = "connections"
Separate from NewLLMConfig because image generation models don't need
system_instructions, citations_enabled, or use_default_system_instructions.
They only need provider credentials and model parameters.
"""
__tablename__ = "image_generation_configs"
name = Column(String(100), nullable=False, index=True)
description = Column(String(500), nullable=True)
# Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
custom_provider = Column(String(100), nullable=True)
model_name = Column(String(100), nullable=False)
# Credentials
api_key = Column(String, nullable=False)
api_base = Column(String(500), nullable=True)
api_version = Column(String(50), nullable=True) # Azure-specific
# Additional litellm parameters
litellm_params = Column(JSON, nullable=True, default={})
# Relationships
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship(
"SearchSpace", back_populates="image_generation_configs"
)
# User who created this config
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
user = relationship("User", back_populates="image_generation_configs")
class VisionLLMConfig(BaseModel, TimestampMixin):
__tablename__ = "vision_llm_configs"
name = Column(String(100), nullable=False, index=True)
description = Column(String(500), nullable=True)
provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False)
custom_provider = Column(String(100), nullable=True)
model_name = Column(String(100), nullable=False)
api_key = Column(String, nullable=False)
api_base = Column(String(500), nullable=True)
api_version = Column(String(50), nullable=True)
litellm_params = Column(JSON, nullable=True, default={})
provider = Column(String(100), nullable=False, index=True)
base_url = Column(String(500), nullable=True)
api_key = Column(String, nullable=True)
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")
scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True)
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
)
search_space = relationship("SearchSpace", back_populates="vision_llm_configs")
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
search_space = relationship("SearchSpace", back_populates="connections")
user = relationship("User", back_populates="connections")
models = relationship(
"Model",
back_populates="connection",
order_by="Model.id",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
CheckConstraint(
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
"(scope = 'USER' AND user_id IS NOT NULL)",
name="ck_connections_scope_owner",
),
)
class Model(BaseModel, TimestampMixin):
__tablename__ = "models"
connection_id = Column(
Integer,
ForeignKey("connections.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
model_id = Column(String(255), nullable=False)
display_name = Column(String(255), nullable=True)
source = Column(
SQLAlchemyEnum(ModelSource),
nullable=False,
default=ModelSource.DISCOVERED,
server_default=ModelSource.DISCOVERED.value,
)
supports_chat = Column(Boolean, nullable=True)
max_input_tokens = Column(Integer, nullable=True)
supports_image_input = Column(Boolean, nullable=True)
supports_tools = Column(Boolean, nullable=True)
supports_image_generation = Column(Boolean, nullable=True)
capabilities_override = Column(
JSONB, nullable=False, default=dict, server_default="{}"
)
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
billing_tier = Column(String(50), nullable=True, index=True)
catalog = Column(JSONB, nullable=False, default=dict, server_default="{}")
connection = relationship("Connection", back_populates="models")
__table_args__ = (
UniqueConstraint(
"connection_id", "model_id", name="uq_models_connection_model_id"
),
Index("ix_models_model_id", "model_id"),
)
user = relationship("User", back_populates="vision_llm_configs")
class ImageGeneration(BaseModel, TimestampMixin):
@ -1707,10 +1650,9 @@ class ImageGeneration(BaseModel, TimestampMixin):
style = Column(String(50), nullable=True) # Model-specific style parameter
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
# Image generation config reference
# 0 = Auto mode (router), negative IDs = global configs from YAML,
# positive IDs = ImageGenerationConfig records in DB
image_generation_config_id = Column(Integer, nullable=True)
# Image generation model provenance.
# 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records.
image_gen_model_id = Column(Integer, nullable=True)
# Response data (full litellm response as JSONB) — present on success
response_data = Column(JSONB, nullable=True)
@ -1752,19 +1694,19 @@ class SearchSpace(BaseModel, TimestampMixin):
shared_memory_md = Column(Text, nullable=True, server_default="")
# Search space-level LLM preferences (shared by all members)
# Note: ID values:
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
# - Negative IDs: Global configs from YAML
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
agent_llm_id = Column(
Integer, nullable=True, default=0
# Connection/model role bindings.
# Note: ID values preserve the existing convention:
# - 0: Auto mode
# - Negative IDs: Global virtual models from global_llm_config.yaml
# - Positive IDs: User/search-space models from the models table
chat_model_id = Column(
Integer, nullable=True, default=0, server_default="0"
) # For agent/chat operations, defaults to Auto mode
image_generation_config_id = Column(
Integer, nullable=True, default=0
) # For image generation, defaults to Auto mode
vision_llm_config_id = Column(
Integer, nullable=True, default=0
image_gen_model_id = Column(
Integer, nullable=True, default=0, server_default="0"
) # For image generation, defaults to Auto mode when eligible
vision_model_id = Column(
Integer, nullable=True, default=0, server_default="0"
) # For vision/screenshot analysis, defaults to Auto mode
ai_file_sort_enabled = Column(
@ -1836,23 +1778,12 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="SearchSourceConnector.id",
cascade="all, delete-orphan",
)
new_llm_configs = relationship(
"NewLLMConfig",
connections = relationship(
"Connection",
back_populates="search_space",
order_by="NewLLMConfig.id",
cascade="all, delete-orphan",
)
image_generation_configs = relationship(
"ImageGenerationConfig",
back_populates="search_space",
order_by="ImageGenerationConfig.id",
cascade="all, delete-orphan",
)
vision_llm_configs = relationship(
"VisionLLMConfig",
back_populates="search_space",
order_by="VisionLLMConfig.id",
order_by="Connection.id",
cascade="all, delete-orphan",
passive_deletes=True,
)
automations = relationship(
@ -1955,64 +1886,6 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
documents = relationship("Document", back_populates="connector")
class NewLLMConfig(BaseModel, TimestampMixin):
"""
New LLM configuration table that combines model settings with prompt configuration.
This table provides:
- LLM model configuration (provider, model_name, api_key, etc.)
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
- Citation toggle (enable/disable citation instructions)
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
"""
__tablename__ = "new_llm_configs"
name = Column(String(100), nullable=False, index=True)
description = Column(String(500), nullable=True)
# === LLM Model Configuration (from original LLMConfig, excluding 'language') ===
# Provider from the enum
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
# Custom provider name when provider is CUSTOM
custom_provider = Column(String(100), nullable=True)
# Just the model name without provider prefix
model_name = Column(String(100), nullable=False)
# API Key should be encrypted before storing
api_key = Column(String, nullable=False)
api_base = Column(String(500), nullable=True)
# For any other parameters that litellm supports
litellm_params = Column(JSON, nullable=True, default={})
# === Prompt Configuration ===
# Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
# Users can customize this from the UI
system_instructions = Column(
Text,
nullable=False,
default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS
)
# Whether to use the default system instructions when system_instructions is empty
use_default_system_instructions = Column(Boolean, nullable=False, default=True)
# Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected
# When disabled, an anti-citation prompt is injected instead
citations_enabled = Column(Boolean, nullable=False, default=True)
# === Relationships ===
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
# User who created this config
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
user = relationship("User", back_populates="new_llm_configs")
class Log(BaseModel, TimestampMixin):
__tablename__ = "logs"
@ -2379,22 +2252,8 @@ if config.AUTH_TYPE == "GOOGLE":
passive_deletes=True,
)
# LLM configs created by this user
new_llm_configs = relationship(
"NewLLMConfig",
back_populates="user",
passive_deletes=True,
)
# Image generation configs created by this user
image_generation_configs = relationship(
"ImageGenerationConfig",
back_populates="user",
passive_deletes=True,
)
vision_llm_configs = relationship(
"VisionLLMConfig",
connections = relationship(
"Connection",
back_populates="user",
passive_deletes=True,
)
@ -2525,22 +2384,8 @@ else:
passive_deletes=True,
)
# LLM configs created by this user
new_llm_configs = relationship(
"NewLLMConfig",
back_populates="user",
passive_deletes=True,
)
# Image generation configs created by this user
image_generation_configs = relationship(
"ImageGenerationConfig",
back_populates="user",
passive_deletes=True,
)
vision_llm_configs = relationship(
"VisionLLMConfig",
connections = relationship(
"Connection",
back_populates="user",
passive_deletes=True,
)

View file

@ -14,6 +14,8 @@ from litellm.exceptions import (
)
from sqlalchemy.exc import IntegrityError as IntegrityError
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
# Tuples for use directly in except clauses.
RETRYABLE_LLM_ERRORS = (
RateLimitError,
@ -97,38 +99,20 @@ def safe_exception_message(exc: Exception) -> str:
def llm_retryable_message(exc: Exception) -> str:
try:
if isinstance(exc, RateLimitError):
return PipelineMessages.RATE_LIMIT
if isinstance(exc, Timeout):
return PipelineMessages.LLM_TIMEOUT
if isinstance(exc, ServiceUnavailableError):
return PipelineMessages.LLM_UNAVAILABLE
if isinstance(exc, BadGatewayError):
return PipelineMessages.LLM_BAD_GATEWAY
if isinstance(exc, InternalServerError):
return PipelineMessages.LLM_SERVER_ERROR
if isinstance(exc, APIConnectionError):
return PipelineMessages.LLM_CONNECTION
return safe_exception_message(exc)
adapted = adapt_llm_exception(exc)
if adapted.category is LLMErrorCategory.UNKNOWN:
return safe_exception_message(exc)
return adapted.user_message
except Exception:
return "Something went wrong when calling the LLM."
def llm_permanent_message(exc: Exception) -> str:
try:
if isinstance(exc, AuthenticationError):
return PipelineMessages.LLM_AUTH
if isinstance(exc, PermissionDeniedError):
return PipelineMessages.LLM_PERMISSION
if isinstance(exc, NotFoundError):
return PipelineMessages.LLM_NOT_FOUND
if isinstance(exc, BadRequestError):
return PipelineMessages.LLM_BAD_REQUEST
if isinstance(exc, UnprocessableEntityError):
return PipelineMessages.LLM_UNPROCESSABLE
if isinstance(exc, APIResponseValidationError):
return PipelineMessages.LLM_RESPONSE
return safe_exception_message(exc)
adapted = adapt_llm_exception(exc)
if adapted.category is LLMErrorCategory.UNKNOWN:
return safe_exception_message(exc)
return adapted.user_message
except Exception:
return "Something went wrong when calling the LLM."

View file

@ -82,7 +82,7 @@ def build_configurable_system_prompt(
*,
model_name: str | None = None,
) -> str:
"""Build a configurable SurfSense system prompt (NewLLMConfig path).
"""Build a configurable SurfSense system prompt.
See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt`
for full parameter docs.
@ -104,7 +104,7 @@ def build_configurable_system_prompt(
def get_default_system_instructions() -> str:
"""Return the default ``<system_instruction>`` block (no tools / citations).
Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``.
Useful for populating the UI when editing custom system instructions.
The output reflects the current fragment tree, not a baked-in constant.
"""
resolved_today = datetime.now(UTC).date().isoformat()

View file

@ -348,8 +348,7 @@ def compose_system_prompt(
mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject
an explicit MCP routing block.
custom_system_instructions: Free-form instructions that override
the default ``<system_instruction>`` block (legacy support
for ``NewLLMConfig.system_instructions``).
the default ``<system_instruction>`` block.
use_default_system_instructions: When ``custom_system_instructions``
is empty/None, fall back to defaults (legacy semantics).
citations_enabled: Include ``citations_on.md`` (true) or

View file

@ -44,9 +44,9 @@ from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .mcp_oauth_route import router as mcp_oauth_router
from .memory_routes import router as memory_router
from .model_connections_routes import router as model_connections_router
from .model_list_routes import router as model_list_router
from .new_chat_routes import router as new_chat_router
from .new_llm_config_routes import router as new_llm_config_router
from .notes_routes import router as notes_router
from .notion_add_connector_route import router as notion_add_connector_router
from .obsidian_plugin_routes import router as obsidian_plugin_router
@ -63,7 +63,6 @@ from .stripe_routes import router as stripe_router
from .team_memory_routes import router as team_memory_router
from .teams_add_connector_route import router as teams_add_connector_router
from .video_presentations_routes import router as video_presentations_router
from .vision_llm_routes import router as vision_llm_router
from .youtube_routes import router as youtube_router
router = APIRouter()
@ -98,7 +97,6 @@ router.include_router(
) # Video presentation status and streaming
router.include_router(reports_router) # Report CRUD and multi-format export
router.include_router(image_generation_router) # Image generation via litellm
router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis
router.include_router(search_source_connectors_router)
router.include_router(google_calendar_add_connector_router)
router.include_router(google_gmail_add_connector_router)
@ -116,7 +114,7 @@ router.include_router(jira_add_connector_router)
router.include_router(confluence_add_connector_router)
router.include_router(clickup_add_connector_router)
router.include_router(dropbox_add_connector_router)
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
router.include_router(model_connections_router) # Connection-centric model catalog
router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter
router.include_router(logs_router)
router.include_router(circleback_webhook_router) # Circleback meeting webhooks

View file

@ -18,6 +18,7 @@ from app.etl_pipeline.file_classifier import (
PLAINTEXT_EXTENSIONS,
)
from app.rate_limiter import limiter
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
logger = logging.getLogger(__name__)
@ -98,7 +99,6 @@ class AnonQuotaResponse(BaseModel):
class AnonModelResponse(BaseModel):
id: int
name: str
description: str | None = None
provider: str
model_name: str
billing_tier: str = "free"
@ -131,8 +131,7 @@ async def list_anonymous_models():
AnonModelResponse(
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("provider", ""),
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",
@ -160,8 +159,7 @@ async def get_anonymous_model(slug: str):
return AnonModelResponse(
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("provider", ""),
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",
@ -474,7 +472,15 @@ async def stream_anonymous_chat(
except Exception as e:
logger.exception("Anonymous chat stream error")
await TokenQuotaService.anon_release(session_key, ip_key, request_id)
yield streaming_service.format_error(f"Error during chat: {e!s}")
_, error_code, _, _, user_message, extra = classify_stream_exception(
e,
flow_label="chat",
)
yield streaming_service.format_error(
user_message,
error_code=error_code,
extra=extra,
)
yield streaming_service.format_done()
finally:
await TokenQuotaService.anon_release_stream_slot(client_ip)

View file

@ -1,7 +1,5 @@
"""
Image Generation routes:
- CRUD for ImageGenerationConfig (user-created image model configs)
- Global image gen configs endpoint (from YAML)
- Image generation execution (calls litellm.aimage_generation())
- CRUD for ImageGeneration records (results)
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
@ -16,11 +14,12 @@ from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import (
ImageGeneration,
ImageGenerationConfig,
Model,
Permission,
SearchSpace,
SearchSpaceMembership,
@ -28,14 +27,14 @@ from app.db import (
get_async_session,
)
from app.schemas import (
GlobalImageGenConfigRead,
ImageGenerationConfigCreate,
ImageGenerationConfigRead,
ImageGenerationConfigUpdate,
ImageGenerationCreate,
ImageGenerationListRead,
ImageGenerationRead,
)
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.billable_calls import (
DEFAULT_IMAGE_RESERVE_MICROS,
QuotaInsufficientError,
@ -43,10 +42,10 @@ from app.services.billable_calls import (
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -54,52 +53,16 @@ from app.utils.signed_image_urls import verify_image_token
router = APIRouter()
logger = logging.getLogger(__name__)
# Provider mapping for building litellm model strings.
# Only includes providers that support image generation.
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini", # Google AI Studio
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock", # AWS Bedrock
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
def _get_global_model(model_id: int) -> dict | None:
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
def _get_global_image_gen_config(config_id: int) -> dict | None:
"""Get a global image generation configuration by ID (negative IDs)."""
if config_id == IMAGE_GEN_AUTO_MODE_ID:
return {
"id": IMAGE_GEN_AUTO_MODE_ID,
"name": "Auto (Fastest)",
"provider": "AUTO",
"model_name": "auto",
"is_auto_mode": True,
}
if config_id > 0:
return None
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
if cfg.get("id") == config_id:
return cfg
return None
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
"""Resolve the LiteLLM provider prefix used in model strings."""
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
"""Build a litellm model string from provider + model_name."""
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
def _get_global_connection(connection_id: int) -> dict | None:
return next(
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
None,
)
async def _resolve_billing_for_image_gen(
@ -115,34 +78,41 @@ async def _resolve_billing_for_image_gen(
config that will actually run, and so we don't open an
``ImageGeneration`` row for a request that's about to 402.
User-owned (positive ID) BYOK configs are always free they cost
the user nothing on our side. Auto mode currently treats as free
because the underlying router can dispatch to either premium or
free YAML configs and we don't surface the resolved deployment up
here yet. Bringing Auto under premium billing would require
threading the chosen deployment back from ``ImageGenRouterService``.
User-owned (positive ID) BYOK models are always free they cost
the user nothing on our side. Auto mode resolves to one concrete
global or BYOK model before billing is calculated.
"""
resolved_id = config_id
if resolved_id is None:
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
if is_image_gen_auto_mode(resolved_id):
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
candidates = await auto_model_candidates(
session,
search_space_id=search_space.id,
user_id=search_space.user_id,
capability="image_gen",
)
if not candidates:
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
selected = choose_auto_model_candidate(candidates, search_space.id)
resolved_id = int(selected["id"])
if resolved_id < 0:
cfg = _get_global_image_gen_config(resolved_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
base_model = _build_model_string(
cfg.get("provider", ""),
cfg.get("model_name", ""),
cfg.get("custom_provider"),
)
global_model = _get_global_model(resolved_id) or {}
global_connection = _get_global_connection(global_model.get("connection_id", 0))
billing_tier = str(global_model.get("billing_tier", "free")).lower()
if global_connection and global_model.get("model_id"):
base_model, _ = to_litellm(global_connection, global_model["model_id"])
else:
base_model = "global_image_model"
catalog = global_model.get("catalog") or {}
reserve_micros = int(
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
)
return (billing_tier, base_model, reserve_micros)
# Positive ID = user-owned BYOK image-gen config — always free.
# Positive ID = user-owned BYOK image-gen model — always free.
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
@ -155,14 +125,14 @@ async def _execute_image_generation(
Call litellm.aimage_generation() with the appropriate config.
Resolution order:
1. Explicit image_generation_config_id on the request
2. Search space's image_generation_config_id preference
1. Explicit image_gen_model_id on the request
2. Search space's image_gen_model_id preference
3. Falls back to Auto mode if available
"""
config_id = image_gen.image_generation_config_id
config_id = image_gen.image_gen_model_id
if config_id is None:
config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
image_gen.image_generation_config_id = config_id
config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
image_gen.image_gen_model_id = config_id
# Build kwargs
gen_kwargs = {}
@ -178,36 +148,30 @@ async def _execute_image_generation(
gen_kwargs["response_format"] = image_gen.response_format
if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized():
raise ValueError(
"Auto mode requested but Image Generation Router not initialized. "
"Ensure global_llm_config.yaml has global_image_generation_configs."
)
response = await ImageGenRouterService.aimage_generation(
prompt=image_gen.prompt, model="auto", **gen_kwargs
candidates = await auto_model_candidates(
session,
search_space_id=search_space.id,
user_id=search_space.user_id,
capability="image_gen",
)
elif config_id < 0:
# Global config from YAML
cfg = _get_global_image_gen_config(config_id)
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
if not candidates:
raise ValueError("No image-generation models are available for Auto mode")
config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"])
image_gen.image_gen_model_id = config_id
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
if config_id < 0:
global_model = _get_global_model(config_id)
if not global_model or not has_capability(global_model, "image_gen"):
raise ValueError(f"Global image generation model {config_id} not found")
global_connection = _get_global_connection(global_model["connection_id"])
if not global_connection:
raise ValueError(f"Global connection for image model {config_id} not found")
model_string, resolved_kwargs = to_litellm(
global_connection,
global_model["model_id"],
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
gen_kwargs.update(cfg["litellm_params"])
gen_kwargs.update(resolved_kwargs)
# User model override
if image_gen.model:
@ -217,30 +181,28 @@ async def _execute_image_generation(
prompt=image_gen.prompt, model=model_string, **gen_kwargs
)
else:
# Positive ID = DB ImageGenerationConfig
# Positive ID = Model + Connection
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
select(Model)
.options(selectinload(Model.connection))
.filter(Model.id == config_id, Model.enabled.is_(True))
)
db_cfg = result.scalars().first()
if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found")
db_model = result.scalars().first()
if not db_model or not db_model.connection or not db_model.connection.enabled:
raise ValueError(f"Image generation model {config_id} not found")
conn = db_model.connection
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
raise ValueError(f"Image generation model {config_id} not found")
if conn.user_id is not None and conn.user_id != search_space.user_id:
raise ValueError(f"Image generation model {config_id} not found")
if not has_capability(db_model, "image_gen"):
raise ValueError(f"Model {config_id} is not image-generation capable")
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
model_string, resolved_kwargs = to_litellm(
db_model.connection,
db_model.model_id,
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
gen_kwargs.update(db_cfg.litellm_params)
gen_kwargs.update(resolved_kwargs)
# User model override
if image_gen.model:
@ -260,266 +222,6 @@ async def _execute_image_generation(
image_gen.model = hidden["model"]
# =============================================================================
# Global Image Generation Configs (from YAML)
# =============================================================================
@router.get(
"/global-image-generation-configs",
response_model=list[GlobalImageGenConfigRead],
)
async def get_global_image_gen_configs(
user: User = Depends(current_active_user),
):
"""Get all global image generation configs. API keys are hidden."""
try:
global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS
safe_configs = []
if global_configs and len(global_configs) > 0:
safe_configs.append(
{
"id": 0,
"name": "Auto (Fastest)",
"description": "Automatically routes across available image generation providers.",
"provider": "AUTO",
"custom_provider": None,
"model_name": "auto",
"api_base": None,
"api_version": None,
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode currently treated as free until per-deployment
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
return safe_configs
except Exception as e:
logger.exception("Failed to fetch global image generation configs")
raise HTTPException(
status_code=500, detail=f"Failed to fetch configs: {e!s}"
) from e
# =============================================================================
# ImageGenerationConfig CRUD
# =============================================================================
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
async def create_image_gen_config(
config_data: ImageGenerationConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Create a new image generation config for a search space."""
try:
await check_permission(
session,
user,
config_data.search_space_id,
Permission.IMAGE_GENERATIONS_CREATE.value,
"You don't have permission to create image generation configs in this search space",
)
db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id)
session.add(db_config)
await session.commit()
await session.refresh(db_config)
return db_config
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to create ImageGenerationConfig")
raise HTTPException(
status_code=500, detail=f"Failed to create config: {e!s}"
) from e
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
async def list_image_gen_configs(
search_space_id: int,
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""List image generation configs for a search space."""
try:
await check_permission(
session,
user,
search_space_id,
Permission.IMAGE_GENERATIONS_READ.value,
"You don't have permission to view image generation configs in this search space",
)
result = await session.execute(
select(ImageGenerationConfig)
.filter(ImageGenerationConfig.search_space_id == search_space_id)
.order_by(ImageGenerationConfig.created_at.desc())
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to list ImageGenerationConfigs")
raise HTTPException(
status_code=500, detail=f"Failed to fetch configs: {e!s}"
) from e
@router.get(
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def get_image_gen_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Get a specific image generation config by ID."""
try:
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.IMAGE_GENERATIONS_READ.value,
"You don't have permission to view image generation configs in this search space",
)
return db_config
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to get ImageGenerationConfig")
raise HTTPException(
status_code=500, detail=f"Failed to fetch config: {e!s}"
) from e
@router.put(
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def update_image_gen_config(
config_id: int,
update_data: ImageGenerationConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Update an existing image generation config."""
try:
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.IMAGE_GENERATIONS_CREATE.value,
"You don't have permission to update image generation configs in this search space",
)
for key, value in update_data.model_dump(exclude_unset=True).items():
setattr(db_config, key, value)
await session.commit()
await session.refresh(db_config)
return db_config
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to update ImageGenerationConfig")
raise HTTPException(
status_code=500, detail=f"Failed to update config: {e!s}"
) from e
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
async def delete_image_gen_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Delete an image generation config."""
try:
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.IMAGE_GENERATIONS_DELETE.value,
"You don't have permission to delete image generation configs in this search space",
)
await session.delete(db_config)
await session.commit()
return {
"message": "Image generation config deleted successfully",
"id": config_id,
}
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to delete ImageGenerationConfig")
raise HTTPException(
status_code=500, detail=f"Failed to delete config: {e!s}"
) from e
# =============================================================================
# Image Generation Execution + Results CRUD
# =============================================================================
@ -568,7 +270,7 @@ async def create_image_generation(
raise HTTPException(status_code=404, detail="Search space not found")
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
session, data.image_generation_config_id, search_space
session, data.image_gen_model_id, search_space
)
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
@ -594,7 +296,7 @@ async def create_image_generation(
size=data.size,
style=data.style,
response_format=data.response_format,
image_generation_config_id=data.image_generation_config_id,
image_gen_model_id=data.image_gen_model_id,
search_space_id=data.search_space_id,
created_by_id=user.id,
)

View file

@ -0,0 +1,805 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import (
Connection,
ConnectionScope,
Model,
ModelSource,
NewChatThread,
Permission,
SearchSpace,
User,
get_async_session,
)
from app.schemas import (
ConnectionCreate,
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelPreviewRead,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelsBulkUpdate,
ModelSelection,
ModelTestPreview,
ModelUpdate,
VerifyConnectionResponse,
)
from app.services.model_capabilities import has_capability
from app.services.model_connection_service import (
ModelDiscoveryError,
derive_capabilities,
discover_models,
test_model,
verify_connection,
)
from app.services.provider_registry import REGISTRY
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
logger = logging.getLogger(__name__)
def _model_read(model: Model | dict) -> ModelRead:
return ModelRead.model_validate(model)
def _preview_model_read(item: dict) -> ModelPreviewRead:
return ModelPreviewRead(
model_id=item["model_id"],
display_name=item.get("display_name"),
source=item.get("source", ModelSource.DISCOVERED),
supports_chat=item.get("supports_chat"),
max_input_tokens=item.get("max_input_tokens"),
supports_image_input=item.get("supports_image_input"),
supports_tools=item.get("supports_tools"),
supports_image_generation=item.get("supports_image_generation"),
enabled=item.get("enabled", False),
metadata=item.get("metadata") or item.get("catalog") or {},
)
def _connection_read(
conn: Connection | dict, models: list[Model | dict] | None = None
) -> ConnectionRead:
if isinstance(conn, dict):
payload = {
**conn,
"has_api_key": bool(conn.get("api_key")),
"api_key": None,
"models": [_model_read(model) for model in (models or [])],
}
payload.pop("api_key", None)
return ConnectionRead.model_validate(payload)
return ConnectionRead(
id=conn.id,
provider=conn.provider,
base_url=conn.base_url,
api_key=conn.api_key,
extra=conn.extra or {},
scope=conn.scope,
search_space_id=conn.search_space_id,
user_id=conn.user_id,
enabled=conn.enabled,
has_api_key=bool(conn.api_key),
models=[_model_read(model) for model in (models or [])],
created_at=conn.created_at,
)
def _apply_model_facts(model: Model, facts: dict) -> None:
model.supports_chat = facts.get("supports_chat")
model.max_input_tokens = facts.get("max_input_tokens")
model.supports_image_input = facts.get("supports_image_input")
model.supports_tools = facts.get("supports_tools")
model.supports_image_generation = facts.get("supports_image_generation")
def _complete_selection_facts(conn: Connection, selection: ModelSelection) -> dict:
facts = selection.model_dump()
derived = derive_capabilities(conn, selection.model_id.strip(), selection.metadata)
for key, value in derived.items():
if facts.get(key) is None:
facts[key] = value
return facts
def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model:
source = (
selection.source
if isinstance(selection.source, ModelSource)
else ModelSource(selection.source)
)
model = Model(
connection_id=conn.id,
model_id=selection.model_id.strip(),
display_name=selection.display_name,
source=source,
capabilities_override={},
enabled=selection.enabled,
catalog=selection.metadata,
)
_apply_model_facts(model, _complete_selection_facts(conn, selection))
return model
def _default_model_for(models: list[Model], capability: str) -> int | None:
for model in models:
if model.enabled and has_capability(model, capability):
return model.id
return None
async def _load_role_model(
session: AsyncSession,
search_space_id: int,
model_id: int,
) -> Model | dict | None:
if model_id < 0:
return next(
(model for model in config.GLOBAL_MODELS if model.get("id") == model_id),
None,
)
result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == model_id)
)
model = result.scalars().first()
if model is None or model.connection.search_space_id != search_space_id:
return None
return model
def _role_model_enabled(model: Model | dict) -> bool:
if isinstance(model, dict):
return bool(model.get("enabled", True))
return bool(model.enabled and model.connection.enabled)
async def _validate_role_model_id(
session: AsyncSession,
*,
search_space_id: int,
model_id: int | None,
capability: str,
) -> int:
if model_id is None or model_id == 0:
return 0
model = await _load_role_model(session, search_space_id, model_id)
if model and _role_model_enabled(model) and has_capability(model, capability):
return model_id
raise HTTPException(
status_code=400,
detail=f"Selected model is not available for {capability}",
)
async def _resolve_role_model_id(
session: AsyncSession,
*,
search_space_id: int,
model_id: int | None,
capability: str,
) -> int:
try:
return await _validate_role_model_id(
session,
search_space_id=search_space_id,
model_id=model_id,
capability=capability,
)
except HTTPException:
return 0
async def _clear_invalid_roles(
session: AsyncSession, search_space_id: int
) -> SearchSpace:
search_space = await _get_search_space(session, search_space_id)
search_space.chat_model_id = await _resolve_role_model_id(
session,
search_space_id=search_space_id,
model_id=search_space.chat_model_id,
capability="chat",
)
search_space.vision_model_id = await _resolve_role_model_id(
session,
search_space_id=search_space_id,
model_id=search_space.vision_model_id,
capability="vision",
)
search_space.image_gen_model_id = await _resolve_role_model_id(
session,
search_space_id=search_space_id,
model_id=search_space.image_gen_model_id,
capability="image_gen",
)
return search_space
async def _default_unset_roles(
session: AsyncSession,
conn: Connection,
models: list[Model],
) -> None:
if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None:
return
search_space = await _get_search_space(session, conn.search_space_id)
if search_space.chat_model_id is None:
search_space.chat_model_id = _default_model_for(models, "chat")
if search_space.vision_model_id is None:
vision_default = None
if search_space.chat_model_id:
chat_model = next(
(m for m in models if m.id == search_space.chat_model_id), None
)
if chat_model and has_capability(chat_model, "vision"):
vision_default = chat_model.id
search_space.vision_model_id = vision_default or _default_model_for(
models, "vision"
)
if search_space.image_gen_model_id is None:
search_space.image_gen_model_id = _default_model_for(models, "image_gen")
@router.get("/model-providers", response_model=list[ModelProviderRead])
async def list_model_providers(user: User = Depends(current_active_user)):
del user
local_only = {"ollama_chat", "lm_studio"}
return [
ModelProviderRead(
provider=provider,
transport=spec.transport.value,
discovery=spec.discovery,
default_base_url=spec.default_base_url,
base_url_required=spec.base_url_required,
auth_style=spec.auth_style,
local_only=provider in local_only,
)
for provider, spec in sorted(REGISTRY.items())
]
async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
return search_space
async def _load_connection(session: AsyncSession, connection_id: int) -> Connection:
result = await session.execute(
select(Connection)
.options(selectinload(Connection.models))
.where(Connection.id == connection_id)
)
conn = result.scalars().first()
if not conn:
raise HTTPException(status_code=404, detail="Connection not found")
return conn
async def _assert_connection_access(
session: AsyncSession,
user: User,
conn: Connection,
permission: str = Permission.LLM_CONFIGS_CREATE.value,
) -> None:
if conn.search_space_id:
await check_permission(
session,
user,
conn.search_space_id,
permission,
"You don't have permission to manage model connections in this search space",
)
return
if conn.user_id != user.id:
raise HTTPException(
status_code=403, detail="Connection does not belong to user"
)
@router.get("/global-model-connections", response_model=list[ConnectionRead])
async def list_global_connections(user: User = Depends(current_active_user)):
del user
models_by_connection: dict[int, list[dict]] = {}
for model in config.GLOBAL_MODELS:
models_by_connection.setdefault(model["connection_id"], []).append(model)
return [
_connection_read(conn, models_by_connection.get(conn["id"], []))
for conn in config.GLOBAL_CONNECTIONS
]
@router.get("/model-connections", response_model=list[ConnectionRead])
async def list_connections(
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
stmt = select(Connection).options(selectinload(Connection.models))
if search_space_id is not None:
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to view model connections in this search space",
)
stmt = stmt.where(Connection.search_space_id == search_space_id)
else:
stmt = stmt.where(Connection.user_id == user.id)
result = await session.execute(stmt.order_by(Connection.id))
return [
_connection_read(conn, list(conn.models)) for conn in result.scalars().all()
]
@router.post("/model-connections", response_model=ConnectionRead)
async def create_connection(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.GLOBAL:
raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only")
if data.scope == ConnectionScope.SEARCH_SPACE:
if data.search_space_id is None:
raise HTTPException(status_code=400, detail="search_space_id is required")
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
payload = data.model_dump(exclude={"search_space_id", "models"})
conn = Connection(
**payload,
search_space_id=data.search_space_id
if data.scope == ConnectionScope.SEARCH_SPACE
else None,
user_id=user.id,
)
session.add(conn)
await session.flush()
seen_model_ids: set[str] = set()
for selection in data.models:
model_id = selection.model_id.strip()
if not model_id or model_id in seen_model_ids:
continue
seen_model_ids.add(model_id)
session.add(_selection_to_model(conn, selection))
await session.commit()
conn = await _load_connection(session, conn.id)
await _default_unset_roles(session, conn, list(conn.models))
await session.commit()
conn = await _load_connection(session, conn.id)
return _connection_read(conn, list(conn.models))
@router.post(
"/model-connections/discover-preview", response_model=list[ModelPreviewRead]
)
async def preview_connection_models(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
draft = Connection(
provider=data.provider,
base_url=data.base_url,
api_key=data.api_key,
extra=data.extra or {},
scope=data.scope,
enabled=data.enabled,
search_space_id=data.search_space_id
if data.scope == ConnectionScope.SEARCH_SPACE
else None,
user_id=user.id,
)
try:
discovered = await discover_models(draft)
except ModelDiscoveryError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return [_preview_model_read(item) for item in discovered]
@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse)
async def test_preview_connection_model(
data: ModelTestPreview,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
model_id = data.model_id.strip()
if not model_id:
raise HTTPException(status_code=400, detail="model_id is required")
draft = Connection(
provider=data.provider,
base_url=data.base_url,
api_key=data.api_key,
extra=data.extra or {},
scope=data.scope,
enabled=data.enabled,
search_space_id=data.search_space_id
if data.scope == ConnectionScope.SEARCH_SPACE
else None,
user_id=user.id,
)
model = Model(
connection_id=0,
model_id=model_id,
source=ModelSource.MANUAL,
enabled=True,
capabilities_override={},
catalog={},
)
result = await test_model(draft, model)
return VerifyConnectionResponse(
status=result.status, ok=result.ok, message=result.message
)
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
async def update_connection(
connection_id: int,
data: ConnectionUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
)
search_space_id = conn.search_space_id
for key, value in data.model_dump(exclude_unset=True).items():
setattr(conn, key, value)
await session.commit()
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
conn = await _load_connection(session, connection_id)
return _connection_read(conn, list(conn.models))
@router.delete("/model-connections/{connection_id}")
async def delete_connection(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_DELETE.value
)
search_space_id = conn.search_space_id
await session.delete(conn)
await session.commit()
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
return {"status": "deleted"}
@router.post(
"/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse
)
async def verify_model_connection(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
)
result = await verify_connection(conn)
return VerifyConnectionResponse(
status=result.status, ok=result.ok, message=result.message
)
@router.post(
"/model-connections/{connection_id}/discover", response_model=list[ModelRead]
)
async def discover_connection_models(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
)
try:
discovered = await discover_models(conn)
except ModelDiscoveryError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
by_model_id = {model.model_id: model for model in conn.models}
for item in discovered:
db_model = by_model_id.get(item["model_id"])
if db_model is None:
db_model = Model(
connection_id=conn.id,
model_id=item["model_id"],
display_name=item.get("display_name"),
source=item["source"],
capabilities_override={},
enabled=False,
catalog=item.get("metadata") or {},
)
_apply_model_facts(db_model, item)
session.add(db_model)
else:
db_model.display_name = item.get("display_name") or db_model.display_name
_apply_model_facts(db_model, item)
db_model.catalog = item.get("metadata") or db_model.catalog
await session.commit()
conn = await _load_connection(session, connection_id)
await _default_unset_roles(session, conn, list(conn.models))
if conn.search_space_id is not None:
await _clear_invalid_roles(session, conn.search_space_id)
await session.commit()
conn = await _load_connection(session, connection_id)
return [_model_read(model) for model in conn.models]
@router.post("/model-connections/{connection_id}/models", response_model=ModelRead)
async def add_manual_model(
connection_id: int,
data: ModelCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
)
model_id = data.model_id.strip()
if not model_id:
raise HTTPException(status_code=400, detail="model_id is required")
if any(existing.model_id == model_id for existing in conn.models):
raise HTTPException(
status_code=400, detail="Model already exists on this connection"
)
capabilities = derive_capabilities(conn, model_id)
model = Model(
connection_id=conn.id,
model_id=model_id,
display_name=data.display_name or None,
source=ModelSource.MANUAL,
capabilities_override={},
enabled=True,
catalog={},
)
_apply_model_facts(model, capabilities)
session.add(model)
await session.commit()
await session.refresh(model)
conn = await _load_connection(session, connection_id)
await _default_unset_roles(session, conn, list(conn.models))
if conn.search_space_id is not None:
await _clear_invalid_roles(session, conn.search_space_id)
await session.commit()
await session.refresh(model)
return _model_read(model)
@router.patch(
"/model-connections/{connection_id}/models", response_model=list[ModelRead]
)
async def bulk_update_models(
connection_id: int,
data: ModelsBulkUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
)
search_space_id = conn.search_space_id
model_ids = set(data.model_ids)
await session.execute(
update(Model)
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
.values(enabled=data.enabled)
)
await session.commit()
session.expire_all()
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
session.expire_all()
result = await session.execute(
select(Model)
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
.order_by(Model.id)
)
return [_model_read(model) for model in result.scalars().all()]
@router.put("/models/{model_id}", response_model=ModelRead)
async def update_model(
model_id: int,
data: ModelUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == model_id)
)
model = result.scalars().first()
if not model:
raise HTTPException(status_code=404, detail="Model not found")
await _assert_connection_access(
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
)
search_space_id = model.connection.search_space_id
update = data.model_dump(exclude_unset=True)
for key, value in update.items():
setattr(model, key, value)
await session.commit()
await session.refresh(model)
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
await session.refresh(model)
return _model_read(model)
@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse)
async def test_connection_model(
model_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == model_id)
)
model = result.scalars().first()
if not model:
raise HTTPException(status_code=404, detail="Model not found")
await _assert_connection_access(
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
)
result = await test_model(model.connection, model)
await session.commit()
return VerifyConnectionResponse(
status=result.status, ok=result.ok, message=result.message
)
@router.get(
"/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead
)
async def get_model_roles(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to view model roles in this search space",
)
search_space = await _clear_invalid_roles(session, search_space_id)
await session.commit()
await session.refresh(search_space)
return ModelRolesRead(
chat_model_id=search_space.chat_model_id,
vision_model_id=search_space.vision_model_id,
image_gen_model_id=search_space.image_gen_model_id,
)
@router.put(
"/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead
)
async def update_model_roles(
search_space_id: int,
data: ModelRolesUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_UPDATE.value,
"You don't have permission to update model roles in this search space",
)
search_space = await _get_search_space(session, search_space_id)
updates = data.model_dump(exclude_unset=True)
if "chat_model_id" in updates:
previous_chat_model_id = search_space.chat_model_id
next_chat_model_id = await _validate_role_model_id(
session,
search_space_id=search_space_id,
model_id=updates["chat_model_id"],
capability="chat",
)
search_space.chat_model_id = next_chat_model_id
if next_chat_model_id != previous_chat_model_id:
await session.execute(
update(NewChatThread)
.where(NewChatThread.search_space_id == search_space_id)
.values(pinned_llm_config_id=None)
)
logger.info(
"Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)",
search_space_id,
previous_chat_model_id,
next_chat_model_id,
)
if "vision_model_id" in updates:
search_space.vision_model_id = await _validate_role_model_id(
session,
search_space_id=search_space_id,
model_id=updates["vision_model_id"],
capability="vision",
)
if "image_gen_model_id" in updates:
search_space.image_gen_model_id = await _validate_role_model_id(
session,
search_space_id=search_space_id,
model_id=updates["image_gen_model_id"],
capability="image_gen",
)
await session.commit()
await session.refresh(search_space)
return ModelRolesRead(
chat_model_id=search_space.chat_model_id,
vision_model_id=search_space.vision_model_id,
image_gen_model_id=search_space.image_gen_model_id,
)

View file

@ -1741,12 +1741,11 @@ async def handle_new_chat(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
# Use agent_llm_id from search space for chat operations
# Positive IDs load from NewLLMConfig database table
# Negative IDs load from YAML global configs
# Falls back to -1 (first global config) if not configured
# Use the converged model-connections role for chat operations.
# Positive IDs load Model + Connection rows; negative IDs load
# virtual GLOBAL models; 0 means Auto.
llm_config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
search_space.chat_model_id if search_space.chat_model_id is not None else 0
)
# Release the read-transaction so we don't hold ACCESS SHARE locks
@ -2228,7 +2227,7 @@ async def regenerate_response(
raise HTTPException(status_code=404, detail="Search space not found")
llm_config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
search_space.chat_model_id if search_space.chat_model_id is not None else 0
)
# Release the read-transaction so we don't hold ACCESS SHARE locks
@ -2393,7 +2392,7 @@ async def resume_chat(
raise HTTPException(status_code=404, detail="Search space not found")
llm_config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
search_space.chat_model_id if search_space.chat_model_id is not None else 0
)
decisions = [d.model_dump() for d in request.decisions]

View file

@ -1,480 +0,0 @@
"""
API routes for NewLLMConfig CRUD operations.
NewLLMConfig combines model settings with prompt configuration:
- LLM provider, model, API key, etc.
- Configurable system instructions
- Citation toggle
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import (
NewLLMConfig,
Permission,
User,
get_async_session,
)
from app.prompts.default_system_instructions import get_default_system_instructions
from app.schemas import (
DefaultSystemInstructionsResponse,
GlobalNewLLMConfigRead,
NewLLMConfigCreate,
NewLLMConfigRead,
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
from app.services.provider_capabilities import derive_supports_image_input
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
logger = logging.getLogger(__name__)
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
There is no DB column for ``supports_image_input`` the value is
resolved at the API boundary from LiteLLM's authoritative model map
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
the response shape consistent across list / detail / create / update
endpoints without having to remember to set the field at every call
site.
"""
provider_value = (
config.provider.value
if hasattr(config.provider, "value")
else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
supports_image_input = derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
)
# ``model_validate`` runs the Pydantic conversion using the ORM
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
# then we layer the derived field on. ``model_copy(update=...)`` keeps
# the surface immutable from the caller's perspective.
base_read = NewLLMConfigRead.model_validate(config)
return base_read.model_copy(update={"supports_image_input": supports_image_input})
# =============================================================================
# Global Configs Routes
# =============================================================================
@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead])
async def get_global_new_llm_configs(
user: User = Depends(current_active_user),
):
"""
Get all available global NewLLMConfig configurations.
These are pre-configured by the system administrator and available to all users.
API keys are not exposed through this endpoint.
Includes:
- Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing
- Global configs (negative IDs): Individual pre-configured LLM providers
"""
try:
global_configs = config.GLOBAL_LLM_CONFIGS
safe_configs = []
# Only include Auto mode if there are actual global configs to route to
# Auto mode requires at least one global config with valid API key
if global_configs and len(global_configs) > 0:
safe_configs.append(
{
"id": 0,
"name": "Auto (Fastest)",
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.",
"provider": "AUTO",
"custom_provider": None,
"model_name": "auto",
"api_base": None,
"litellm_params": {},
"system_instructions": "",
"use_default_system_instructions": True,
"citations_enabled": True,
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
"is_premium": False,
"anonymous_enabled": False,
"seo_enabled": False,
"seo_slug": None,
"seo_title": None,
"seo_description": None,
"quota_reserve_tokens": None,
# Auto routes across the configured pool, which usually
# includes at least one vision-capable deployment, so
# treat Auto as image-capable. The router itself will
# still pick a vision-capable deployment for messages
# carrying image_url blocks (LiteLLM Router falls back
# on ``404`` per its ``allowed_fails`` policy).
"supports_image_input": True,
}
)
# Add individual global configs
for cfg in global_configs:
# Capability resolution: explicit value (YAML override or OR
# `_supports_image_input(model)` payload baked in by the
# OpenRouter integration service) wins. Fall back to the
# LiteLLM-driven helper which default-allows on unknown so
# we don't hide vision-capable models that happen to lack a
# YAML annotation. The streaming task safety net is the
# only place a False ever blocks.
if "supports_image_input" in cfg:
supports_image_input = bool(cfg.get("supports_image_input"))
else:
cfg_litellm_params = cfg.get("litellm_params") or {}
cfg_base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
supports_image_input = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=cfg_base_model,
custom_provider=cfg.get("custom_provider"),
)
safe_config = {
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,
"litellm_params": cfg.get("litellm_params", {}),
# New prompt configuration fields
"system_instructions": cfg.get("system_instructions", ""),
"use_default_system_instructions": cfg.get(
"use_default_system_instructions", True
),
"citations_enabled": cfg.get("citations_enabled", True),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
"is_premium": cfg.get("billing_tier", "free") == "premium",
"anonymous_enabled": cfg.get("anonymous_enabled", False),
"seo_enabled": cfg.get("seo_enabled", False),
"seo_slug": cfg.get("seo_slug"),
"seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"supports_image_input": supports_image_input,
}
safe_configs.append(safe_config)
return safe_configs
except Exception as e:
logger.exception("Failed to fetch global NewLLMConfigs")
raise HTTPException(
status_code=500, detail=f"Failed to fetch global configurations: {e!s}"
) from e
# =============================================================================
# CRUD Routes
# =============================================================================
@router.post("/new-llm-configs", response_model=NewLLMConfigRead)
async def create_new_llm_config(
config_data: NewLLMConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Create a new NewLLMConfig for a search space.
Requires LLM_CONFIGS_CREATE permission.
"""
try:
# Verify user has permission
await check_permission(
session,
user,
config_data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create LLM configurations in this search space",
)
# Validate the LLM configuration by making a test API call
is_valid, error_message = await validate_llm_config(
provider=config_data.provider.value,
model_name=config_data.model_name,
api_key=config_data.api_key,
api_base=config_data.api_base,
custom_provider=config_data.custom_provider,
litellm_params=config_data.litellm_params,
)
if not is_valid:
raise HTTPException(
status_code=400,
detail=f"Invalid LLM configuration: {error_message}",
)
# Create the config with user association
db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id)
session.add(db_config)
await session.commit()
await session.refresh(db_config)
return _serialize_byok_config(db_config)
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to create NewLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to create configuration: {e!s}"
) from e
@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead])
async def list_new_llm_configs(
search_space_id: int,
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Get all NewLLMConfigs for a search space.
Requires LLM_CONFIGS_READ permission.
"""
try:
# Verify user has permission
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_READ.value,
"You don't have permission to view LLM configurations in this search space",
)
result = await session.execute(
select(NewLLMConfig)
.filter(NewLLMConfig.search_space_id == search_space_id)
.order_by(NewLLMConfig.created_at.desc())
.offset(skip)
.limit(limit)
)
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to list NewLLMConfigs")
raise HTTPException(
status_code=500, detail=f"Failed to fetch configurations: {e!s}"
) from e
@router.get(
"/new-llm-configs/default-system-instructions",
response_model=DefaultSystemInstructionsResponse,
)
async def get_default_system_instructions_endpoint(
user: User = Depends(current_active_user),
):
"""
Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template.
Useful for pre-populating the UI when creating a new configuration.
"""
return DefaultSystemInstructionsResponse(
default_system_instructions=get_default_system_instructions()
)
@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
async def get_new_llm_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Get a specific NewLLMConfig by ID.
Requires LLM_CONFIGS_READ permission.
"""
try:
result = await session.execute(
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
)
config = result.scalars().first()
if not config:
raise HTTPException(status_code=404, detail="Configuration not found")
# Verify user has permission
await check_permission(
session,
user,
config.search_space_id,
Permission.LLM_CONFIGS_READ.value,
"You don't have permission to view LLM configurations in this search space",
)
return _serialize_byok_config(config)
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to get NewLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to fetch configuration: {e!s}"
) from e
@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
async def update_new_llm_config(
config_id: int,
update_data: NewLLMConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Update an existing NewLLMConfig.
Requires LLM_CONFIGS_UPDATE permission.
"""
try:
result = await session.execute(
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
)
config = result.scalars().first()
if not config:
raise HTTPException(status_code=404, detail="Configuration not found")
# Verify user has permission
await check_permission(
session,
user,
config.search_space_id,
Permission.LLM_CONFIGS_UPDATE.value,
"You don't have permission to update LLM configurations in this search space",
)
update_dict = update_data.model_dump(exclude_unset=True)
# If updating LLM settings, validate them
if any(
key in update_dict
for key in [
"provider",
"model_name",
"api_key",
"api_base",
"custom_provider",
"litellm_params",
]
):
# Build the validation config from existing + updates
validation_config = {
"provider": update_dict.get("provider", config.provider).value
if hasattr(update_dict.get("provider", config.provider), "value")
else update_dict.get("provider", config.provider.value),
"model_name": update_dict.get("model_name", config.model_name),
"api_key": update_dict.get("api_key", config.api_key),
"api_base": update_dict.get("api_base", config.api_base),
"custom_provider": update_dict.get(
"custom_provider", config.custom_provider
),
"litellm_params": update_dict.get(
"litellm_params", config.litellm_params
),
}
is_valid, error_message = await validate_llm_config(
provider=validation_config["provider"],
model_name=validation_config["model_name"],
api_key=validation_config["api_key"],
api_base=validation_config["api_base"],
custom_provider=validation_config["custom_provider"],
litellm_params=validation_config["litellm_params"],
)
if not is_valid:
raise HTTPException(
status_code=400,
detail=f"Invalid LLM configuration: {error_message}",
)
# Apply updates
for key, value in update_dict.items():
setattr(config, key, value)
await session.commit()
await session.refresh(config)
return _serialize_byok_config(config)
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to update NewLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to update configuration: {e!s}"
) from e
@router.delete("/new-llm-configs/{config_id}", response_model=dict)
async def delete_new_llm_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Delete a NewLLMConfig.
Requires LLM_CONFIGS_DELETE permission.
"""
try:
result = await session.execute(
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
)
config = result.scalars().first()
if not config:
raise HTTPException(status_code=404, detail="Configuration not found")
# Verify user has permission
await check_permission(
session,
user,
config.search_space_id,
Permission.LLM_CONFIGS_DELETE.value,
"You don't have permission to delete LLM configurations in this search space",
)
await session.delete(config)
await session.commit()
return {"message": "Configuration deleted successfully", "id": config_id}
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to delete NewLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to delete configuration: {e!s}"
) from e

View file

@ -1,27 +1,20 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import func, update
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import (
ImageGenerationConfig,
NewChatThread,
NewLLMConfig,
Permission,
SearchSpace,
SearchSpaceMembership,
SearchSpaceRole,
User,
VisionLLMConfig,
get_async_session,
get_default_roles_config,
)
from app.schemas import (
LLMPreferencesRead,
LLMPreferencesUpdate,
SearchSpaceCreate,
SearchSpaceRead,
SearchSpaceUpdate,
@ -377,357 +370,6 @@ async def delete_search_space(
) from e
# =============================================================================
# LLM Preferences Routes
# =============================================================================
async def _get_llm_config_by_id(
session: AsyncSession, config_id: int | None
) -> dict | None:
"""
Get an LLM config by ID as a dictionary. Returns database config for positive IDs,
global config for negative IDs, Auto mode config for ID 0, or None if ID is None.
"""
if config_id is None:
return None
# Auto mode (ID 0) - uses LiteLLM Router for load balancing
if config_id == 0:
return {
"id": 0,
"name": "Auto (Fastest)",
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
"provider": "AUTO",
"custom_provider": None,
"model_name": "auto",
"api_base": None,
"litellm_params": {},
"system_instructions": "",
"use_default_system_instructions": True,
"citations_enabled": True,
"is_global": True,
"is_auto_mode": True,
}
if config_id < 0:
# Global config - find from YAML
global_configs = config.GLOBAL_LLM_CONFIGS
for cfg in global_configs:
if cfg.get("id") == config_id:
return {
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base"),
"litellm_params": cfg.get("litellm_params", {}),
"system_instructions": cfg.get("system_instructions", ""),
"use_default_system_instructions": cfg.get(
"use_default_system_instructions", True
),
"citations_enabled": cfg.get("citations_enabled", True),
"is_global": True,
}
return None
else:
# Database config - convert to dict
result = await session.execute(
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
)
db_config = result.scalars().first()
if db_config:
return {
"id": db_config.id,
"name": db_config.name,
"description": db_config.description,
"provider": db_config.provider.value if db_config.provider else None,
"custom_provider": db_config.custom_provider,
"model_name": db_config.model_name,
"api_key": db_config.api_key,
"api_base": db_config.api_base,
"litellm_params": db_config.litellm_params or {},
"system_instructions": db_config.system_instructions or "",
"use_default_system_instructions": db_config.use_default_system_instructions,
"citations_enabled": db_config.citations_enabled,
"created_at": db_config.created_at.isoformat()
if db_config.created_at
else None,
"search_space_id": db_config.search_space_id,
}
return None
async def _get_image_gen_config_by_id(
session: AsyncSession, config_id: int | None
) -> dict | None:
"""
Get an image generation config by ID as a dictionary.
Returns Auto mode for ID 0, global config for negative IDs,
DB ImageGenerationConfig for positive IDs, or None.
"""
if config_id is None:
return None
if config_id == 0:
return {
"id": 0,
"name": "Auto (Fastest)",
"description": "Automatically routes requests across available image generation providers",
"provider": "AUTO",
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
if cfg.get("id") == config_id:
return {
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None
# Positive ID: query ImageGenerationConfig table
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
)
db_config = result.scalars().first()
if db_config:
return {
"id": db_config.id,
"name": db_config.name,
"description": db_config.description,
"provider": db_config.provider.value if db_config.provider else None,
"custom_provider": db_config.custom_provider,
"model_name": db_config.model_name,
"api_base": db_config.api_base,
"api_version": db_config.api_version,
"litellm_params": db_config.litellm_params or {},
"created_at": db_config.created_at.isoformat()
if db_config.created_at
else None,
"search_space_id": db_config.search_space_id,
}
return None
async def _get_vision_llm_config_by_id(
session: AsyncSession, config_id: int | None
) -> dict | None:
if config_id is None:
return None
if config_id == 0:
return {
"id": 0,
"name": "Auto (Fastest)",
"description": "Automatically routes requests across available vision LLM providers",
"provider": "AUTO",
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
if cfg.get("id") == config_id:
return {
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None
result = await session.execute(
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
)
db_config = result.scalars().first()
if db_config:
return {
"id": db_config.id,
"name": db_config.name,
"description": db_config.description,
"provider": db_config.provider.value if db_config.provider else None,
"custom_provider": db_config.custom_provider,
"model_name": db_config.model_name,
"api_base": db_config.api_base,
"api_version": db_config.api_version,
"litellm_params": db_config.litellm_params or {},
"created_at": db_config.created_at.isoformat()
if db_config.created_at
else None,
"search_space_id": db_config.search_space_id,
}
return None
@router.get(
"/search-spaces/{search_space_id}/llm-preferences",
response_model=LLMPreferencesRead,
)
async def get_llm_preferences(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Get LLM preferences (role assignments) for a search space.
Requires LLM_CONFIGS_READ permission.
"""
try:
# Check permission
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_READ.value,
"You don't have permission to view LLM preferences",
)
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
# Get full config objects for each role
agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id)
image_generation_config = await _get_image_gen_config_by_id(
session, search_space.image_generation_config_id
)
vision_llm_config = await _get_vision_llm_config_by_id(
session, search_space.vision_llm_config_id
)
return LLMPreferencesRead(
agent_llm_id=search_space.agent_llm_id,
image_generation_config_id=search_space.image_generation_config_id,
vision_llm_config_id=search_space.vision_llm_config_id,
agent_llm=agent_llm,
image_generation_config=image_generation_config,
vision_llm_config=vision_llm_config,
)
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to get LLM preferences")
raise HTTPException(
status_code=500, detail=f"Failed to get LLM preferences: {e!s}"
) from e
@router.put(
"/search-spaces/{search_space_id}/llm-preferences",
response_model=LLMPreferencesRead,
)
async def update_llm_preferences(
search_space_id: int,
preferences: LLMPreferencesUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Update LLM preferences (role assignments) for a search space.
Requires LLM_CONFIGS_UPDATE permission.
"""
try:
# Check permission
await check_permission(
session,
user,
search_space_id,
Permission.LLM_CONFIGS_UPDATE.value,
"You don't have permission to update LLM preferences",
)
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
# Update preferences
update_data = preferences.model_dump(exclude_unset=True)
previous_agent_llm_id = search_space.agent_llm_id
for key, value in update_data.items():
setattr(search_space, key, value)
agent_llm_changed = (
"agent_llm_id" in update_data
and update_data["agent_llm_id"] != previous_agent_llm_id
)
if agent_llm_changed:
await session.execute(
update(NewChatThread)
.where(NewChatThread.search_space_id == search_space_id)
.values(pinned_llm_config_id=None)
)
logger.info(
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
search_space_id,
previous_agent_llm_id,
update_data["agent_llm_id"],
)
await session.commit()
await session.refresh(search_space)
# Get full config objects for response
agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id)
image_generation_config = await _get_image_gen_config_by_id(
session, search_space.image_generation_config_id
)
vision_llm_config = await _get_vision_llm_config_by_id(
session, search_space.vision_llm_config_id
)
return LLMPreferencesRead(
agent_llm_id=search_space.agent_llm_id,
image_generation_config_id=search_space.image_generation_config_id,
vision_llm_config_id=search_space.vision_llm_config_id,
agent_llm=agent_llm,
image_generation_config=image_generation_config,
vision_llm_config=vision_llm_config,
)
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to update LLM preferences")
raise HTTPException(
status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
) from e
@router.get("/searchspaces/{search_space_id}/snapshots")
async def list_search_space_snapshots(
search_space_id: int,

View file

@ -1,304 +0,0 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
Permission,
User,
VisionLLMConfig,
get_async_session,
)
from app.schemas import (
GlobalVisionLLMConfigRead,
VisionLLMConfigCreate,
VisionLLMConfigRead,
VisionLLMConfigUpdate,
)
from app.services.vision_model_list_service import get_vision_model_list
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
logger = logging.getLogger(__name__)
# =============================================================================
# Vision Model Catalogue (from OpenRouter, filtered for image-input models)
# =============================================================================
class VisionModelListItem(BaseModel):
value: str
label: str
provider: str
context_window: str | None = None
@router.get("/vision-models", response_model=list[VisionModelListItem])
async def list_vision_models(
user: User = Depends(current_active_user),
):
"""Return vision-capable models sourced from OpenRouter (filtered by image input)."""
try:
return await get_vision_model_list()
except Exception as e:
logger.exception("Failed to fetch vision model list")
raise HTTPException(
status_code=500, detail=f"Failed to fetch vision model list: {e!s}"
) from e
# =============================================================================
# Global Vision LLM Configs (from YAML)
# =============================================================================
@router.get(
"/global-vision-llm-configs",
response_model=list[GlobalVisionLLMConfigRead],
)
async def get_global_vision_llm_configs(
user: User = Depends(current_active_user),
):
try:
global_configs = config.GLOBAL_VISION_LLM_CONFIGS
safe_configs = []
if global_configs and len(global_configs) > 0:
safe_configs.append(
{
"id": 0,
"name": "Auto (Fastest)",
"description": "Automatically routes across available vision LLM providers.",
"provider": "AUTO",
"custom_provider": None,
"model_name": "auto",
"api_base": None,
"api_version": None,
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode treated as free until per-deployment billing-tier
# surfacing lands; see ``get_vision_llm`` for parity.
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"input_cost_per_token": cfg.get("input_cost_per_token"),
"output_cost_per_token": cfg.get("output_cost_per_token"),
}
)
return safe_configs
except Exception as e:
logger.exception("Failed to fetch global vision LLM configs")
raise HTTPException(
status_code=500, detail=f"Failed to fetch configs: {e!s}"
) from e
# =============================================================================
# VisionLLMConfig CRUD
# =============================================================================
@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead)
async def create_vision_llm_config(
config_data: VisionLLMConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
try:
await check_permission(
session,
user,
config_data.search_space_id,
Permission.VISION_CONFIGS_CREATE.value,
"You don't have permission to create vision LLM configs in this search space",
)
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id)
session.add(db_config)
await session.commit()
await session.refresh(db_config)
return db_config
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to create VisionLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to create config: {e!s}"
) from e
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead])
async def list_vision_llm_configs(
search_space_id: int,
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
try:
await check_permission(
session,
user,
search_space_id,
Permission.VISION_CONFIGS_READ.value,
"You don't have permission to view vision LLM configs in this search space",
)
result = await session.execute(
select(VisionLLMConfig)
.filter(VisionLLMConfig.search_space_id == search_space_id)
.order_by(VisionLLMConfig.created_at.desc())
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to list VisionLLMConfigs")
raise HTTPException(
status_code=500, detail=f"Failed to fetch configs: {e!s}"
) from e
@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
async def get_vision_llm_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
try:
result = await session.execute(
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.VISION_CONFIGS_READ.value,
"You don't have permission to view vision LLM configs in this search space",
)
return db_config
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to get VisionLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to fetch config: {e!s}"
) from e
@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
async def update_vision_llm_config(
config_id: int,
update_data: VisionLLMConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
try:
result = await session.execute(
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.VISION_CONFIGS_CREATE.value,
"You don't have permission to update vision LLM configs in this search space",
)
for key, value in update_data.model_dump(exclude_unset=True).items():
setattr(db_config, key, value)
await session.commit()
await session.refresh(db_config)
return db_config
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to update VisionLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to update config: {e!s}"
) from e
@router.delete("/vision-llm-configs/{config_id}", response_model=dict)
async def delete_vision_llm_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
try:
result = await session.execute(
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.VISION_CONFIGS_DELETE.value,
"You don't have permission to delete vision LLM configs in this search space",
)
await session.delete(db_config)
await session.commit()
return {
"message": "Vision LLM config deleted successfully",
"id": config_id,
}
except HTTPException:
raise
except Exception as e:
await session.rollback()
logger.exception("Failed to delete VisionLLMConfig")
raise HTTPException(
status_code=500, detail=f"Failed to delete config: {e!s}"
) from e

View file

@ -34,16 +34,27 @@ from .folders import (
)
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
from .image_generation import (
GlobalImageGenConfigRead,
ImageGenerationConfigCreate,
ImageGenerationConfigPublic,
ImageGenerationConfigRead,
ImageGenerationConfigUpdate,
ImageGenerationCreate,
ImageGenerationListRead,
ImageGenerationRead,
)
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
from .model_connections import (
ConnectionCreate,
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelPreviewRead,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelsBulkUpdate,
ModelSelection,
ModelTestPreview,
ModelUpdate,
VerifyConnectionResponse,
)
from .new_chat import (
ChatMessage,
NewChatMessageAppend,
@ -58,16 +69,6 @@ from .new_chat import (
ThreadListItem,
ThreadListResponse,
)
from .new_llm_config import (
DefaultSystemInstructionsResponse,
GlobalNewLLMConfigRead,
LLMPreferencesRead,
LLMPreferencesUpdate,
NewLLMConfigCreate,
NewLLMConfigPublic,
NewLLMConfigRead,
NewLLMConfigUpdate,
)
from .rbac_schemas import (
InviteAcceptRequest,
InviteAcceptResponse,
@ -126,13 +127,6 @@ from .video_presentations import (
VideoPresentationRead,
VideoPresentationUpdate,
)
from .vision_llm import (
GlobalVisionLLMConfigRead,
VisionLLMConfigCreate,
VisionLLMConfigPublic,
VisionLLMConfigRead,
VisionLLMConfigUpdate,
)
__all__ = [
# Folder schemas
@ -144,12 +138,15 @@ __all__ = [
"ChunkCreate",
"ChunkRead",
"ChunkUpdate",
# Model connection schemas
"ConnectionCreate",
"ConnectionRead",
"ConnectionUpdate",
"CreateCreditCheckoutSessionRequest",
"CreateCreditCheckoutSessionResponse",
"CreditPurchaseHistoryResponse",
"CreditPurchaseRead",
"CreditStripeStatusResponse",
"DefaultSystemInstructionsResponse",
# Document schemas
"DocumentBase",
"DocumentMove",
@ -172,19 +169,10 @@ __all__ = [
"FolderRead",
"FolderReorder",
"FolderUpdate",
"GlobalImageGenConfigRead",
"GlobalNewLLMConfigRead",
# Vision LLM Config schemas
"GlobalVisionLLMConfigRead",
"GoogleDriveIndexRequest",
"GoogleDriveIndexingOptions",
# Base schemas
"IDModel",
# Image Generation Config schemas
"ImageGenerationConfigCreate",
"ImageGenerationConfigPublic",
"ImageGenerationConfigRead",
"ImageGenerationConfigUpdate",
# Image Generation schemas
"ImageGenerationCreate",
"ImageGenerationListRead",
@ -196,9 +184,6 @@ __all__ = [
"InviteInfoResponse",
"InviteRead",
"InviteUpdate",
# LLM Preferences schemas
"LLMPreferencesRead",
"LLMPreferencesUpdate",
# Log schemas
"LogBase",
"LogCreate",
@ -217,6 +202,16 @@ __all__ = [
"MembershipRead",
"MembershipReadWithUser",
"MembershipUpdate",
"ModelCreate",
"ModelPreviewRead",
"ModelProviderRead",
"ModelRead",
"ModelRolesRead",
"ModelRolesUpdate",
"ModelSelection",
"ModelTestPreview",
"ModelUpdate",
"ModelsBulkUpdate",
"NewChatMessageAppend",
"NewChatMessageCreate",
"NewChatMessageRead",
@ -225,11 +220,6 @@ __all__ = [
"NewChatThreadRead",
"NewChatThreadUpdate",
"NewChatThreadWithMessages",
# NewLLMConfig schemas
"NewLLMConfigCreate",
"NewLLMConfigPublic",
"NewLLMConfigRead",
"NewLLMConfigUpdate",
"PagePurchaseHistoryResponse",
"PagePurchaseRead",
"PaginatedResponse",
@ -267,13 +257,10 @@ __all__ = [
"UserRead",
"UserSearchSpaceAccess",
"UserUpdate",
"VerifyConnectionResponse",
# Video Presentation schemas
"VideoPresentationBase",
"VideoPresentationCreate",
"VideoPresentationRead",
"VideoPresentationUpdate",
"VisionLLMConfigCreate",
"VisionLLMConfigPublic",
"VisionLLMConfigRead",
"VisionLLMConfigUpdate",
]

View file

@ -1,109 +1,10 @@
"""
Pydantic schemas for Image Generation configs and generation requests.
"""Pydantic schemas for image generation requests/results."""
ImageGenerationConfig: CRUD schemas for user-created image gen model configs.
ImageGeneration: Schemas for the actual image generation requests/results.
GlobalImageGenConfigRead: Schema for admin-configured YAML configs.
"""
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import ImageGenProvider
# =============================================================================
# ImageGenerationConfig CRUD Schemas
# =============================================================================
class ImageGenerationConfigBase(BaseModel):
"""Base schema with fields for ImageGenerationConfig."""
name: str = Field(
..., max_length=100, description="User-friendly name for the config"
)
description: str | None = Field(
None, max_length=500, description="Optional description"
)
provider: ImageGenProvider = Field(
...,
description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)",
)
custom_provider: str | None = Field(
None, max_length=100, description="Custom provider name"
)
model_name: str = Field(
..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)"
)
api_key: str = Field(..., description="API key for the provider")
api_base: str | None = Field(
None, max_length=500, description="Optional API base URL"
)
api_version: str | None = Field(
None,
max_length=50,
description="Azure-specific API version (e.g., '2024-02-15-preview')",
)
litellm_params: dict[str, Any] | None = Field(
default=None, description="Additional LiteLLM parameters"
)
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
"""Schema for creating a new ImageGenerationConfig."""
search_space_id: int = Field(
..., description="Search space ID to associate the config with"
)
class ImageGenerationConfigUpdate(BaseModel):
"""Schema for updating an existing ImageGenerationConfig. All fields optional."""
name: str | None = Field(None, max_length=100)
description: str | None = Field(None, max_length=500)
provider: ImageGenProvider | None = None
custom_provider: str | None = Field(None, max_length=100)
model_name: str | None = Field(None, max_length=100)
api_key: str | None = None
api_base: str | None = Field(None, max_length=500)
api_version: str | None = Field(None, max_length=50)
litellm_params: dict[str, Any] | None = None
class ImageGenerationConfigRead(ImageGenerationConfigBase):
"""Schema for reading an ImageGenerationConfig (includes id and timestamps)."""
id: int
created_at: datetime
search_space_id: int
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
class ImageGenerationConfigPublic(BaseModel):
"""Public schema that hides the API key (for list views)."""
id: int
name: str
description: str | None = None
provider: ImageGenProvider
custom_provider: str | None = None
model_name: str
api_base: str | None = None
api_version: str | None = None
litellm_params: dict[str, Any] | None = None
created_at: datetime
search_space_id: int
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
# =============================================================================
# ImageGeneration (request/result) Schemas
# =============================================================================
@ -136,12 +37,12 @@ class ImageGenerationCreate(BaseModel):
search_space_id: int = Field(
..., description="Search space ID to associate the generation with"
)
image_generation_config_id: int | None = Field(
image_gen_model_id: int | None = Field(
None,
description=(
"Image generation config ID. "
"0 = Auto mode (router), negative = global YAML config, positive = DB config. "
"If not provided, uses the search space's image_generation_config_id preference."
"Image generation model ID. "
"0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. "
"If not provided, uses the search space's image_gen_model_id preference."
),
)
@ -157,7 +58,7 @@ class ImageGenerationRead(BaseModel):
size: str | None = None
style: str | None = None
response_format: str | None = None
image_generation_config_id: int | None = None
image_gen_model_id: int | None = None
response_data: dict[str, Any] | None = None
error_message: str | None = None
search_space_id: int
@ -203,58 +104,3 @@ class ImageGenerationListRead(BaseModel):
is_success=obj.response_data is not None,
image_count=image_count,
)
# =============================================================================
# Global Image Gen Config (from YAML)
# =============================================================================
class GlobalImageGenConfigRead(BaseModel):
"""
Schema for reading global image generation configs from YAML.
Global configs have negative IDs. API key is hidden.
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
The ``billing_tier`` field allows the frontend to show a Premium/Free
badge and (more importantly) tells the backend whether to debit the
user's premium credit pool when this config is used. ``"free"`` is
the default for backward compatibility admins must explicitly opt
a global config into ``"premium"``.
"""
id: int = Field(
...,
description="Config ID: 0 for Auto mode, negative for global configs",
)
name: str
description: str | None = None
provider: str
custom_provider: str | None = None
model_name: str
api_base: str | None = None
api_version: str | None = None
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
billing_tier: str = Field(
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_micros: int | None = Field(
default=None,
description=(
"Optional override for the reservation amount (in micro-USD) used when "
"this image generation is premium. Falls back to "
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
),
)

View file

@ -0,0 +1,148 @@
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import ConnectionScope, ModelSource
class ModelRead(BaseModel):
id: int
connection_id: int
model_id: str
display_name: str | None = None
source: ModelSource | str
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
capabilities_override: dict[str, Any] = Field(default_factory=dict)
enabled: bool
billing_tier: str | None = None
catalog: dict[str, Any] = Field(default_factory=dict)
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class ConnectionRead(BaseModel):
id: int
provider: str
base_url: str | None = None
api_key: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
scope: ConnectionScope | str
search_space_id: int | None = None
user_id: uuid.UUID | None = None
enabled: bool
has_api_key: bool
models: list[ModelRead] = Field(default_factory=list)
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class ModelSelection(BaseModel):
model_id: str = Field(..., max_length=255)
display_name: str | None = Field(None, max_length=255)
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ModelPreviewRead(BaseModel):
model_id: str
display_name: str | None = None
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ConnectionCreate(BaseModel):
provider: str = Field(..., max_length=100)
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
scope: ConnectionScope = ConnectionScope.SEARCH_SPACE
search_space_id: int | None = None
enabled: bool = True
models: list[ModelSelection] = Field(default_factory=list)
class ModelTestPreview(ConnectionCreate):
model_id: str = Field(..., max_length=255)
class ConnectionUpdate(BaseModel):
provider: str | None = Field(None, max_length=100)
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] | None = None
enabled: bool | None = None
class ModelCreate(BaseModel):
"""Manually register a model id on a connection.
For providers without a usable ``/models`` endpoint (Perplexity, MiniMax,
Azure deployments, etc.) or to pin a single model from a noisy provider.
"""
model_id: str = Field(..., max_length=255)
display_name: str | None = Field(None, max_length=255)
class ModelUpdate(BaseModel):
display_name: str | None = Field(None, max_length=255)
enabled: bool | None = None
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
capabilities_override: dict[str, Any] | None = None
class ModelsBulkUpdate(BaseModel):
model_ids: list[int] = Field(..., min_length=1, max_length=1000)
enabled: bool
class ModelProviderRead(BaseModel):
provider: str
transport: str
discovery: str
default_base_url: str | None = None
base_url_required: bool
auth_style: str
local_only: bool = False
class VerifyConnectionResponse(BaseModel):
status: str
ok: bool
message: str = ""
class ModelRolesRead(BaseModel):
chat_model_id: int | None = 0
vision_model_id: int | None = 0
image_gen_model_id: int | None = 0
class ModelRolesUpdate(BaseModel):
chat_model_id: int | None = None
vision_model_id: int | None = None
image_gen_model_id: int | None = None

View file

@ -1,256 +0,0 @@
"""
Pydantic schemas for the NewLLMConfig API.
NewLLMConfig combines model settings with prompt configuration:
- LLM provider, model, API key, etc.
- Configurable system instructions
- Citation toggle
"""
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import LiteLLMProvider
class NewLLMConfigBase(BaseModel):
"""Base schema with common fields for NewLLMConfig."""
name: str = Field(
..., max_length=100, description="User-friendly name for the configuration"
)
description: str | None = Field(
None, max_length=500, description="Optional description"
)
# Model Configuration
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
custom_provider: str | None = Field(
None, max_length=100, description="Custom provider name when provider is CUSTOM"
)
model_name: str = Field(
..., max_length=100, description="Model name without provider prefix"
)
api_key: str = Field(..., description="API key for the provider")
api_base: str | None = Field(
None, max_length=500, description="Optional API base URL"
)
litellm_params: dict[str, Any] | None = Field(
default=None, description="Additional LiteLLM parameters"
)
# Prompt Configuration
system_instructions: str = Field(
default="",
description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.",
)
use_default_system_instructions: bool = Field(
default=True,
description="Whether to use default instructions when system_instructions is empty",
)
citations_enabled: bool = Field(
default=True,
description="Whether to include citation instructions in the system prompt",
)
class NewLLMConfigCreate(NewLLMConfigBase):
"""Schema for creating a new NewLLMConfig."""
search_space_id: int = Field(
..., description="Search space ID to associate the config with"
)
class NewLLMConfigUpdate(BaseModel):
"""Schema for updating an existing NewLLMConfig. All fields are optional."""
name: str | None = Field(None, max_length=100)
description: str | None = Field(None, max_length=500)
# Model Configuration
provider: LiteLLMProvider | None = None
custom_provider: str | None = Field(None, max_length=100)
model_name: str | None = Field(None, max_length=100)
api_key: str | None = None
api_base: str | None = Field(None, max_length=500)
litellm_params: dict[str, Any] | None = None
# Prompt Configuration
system_instructions: str | None = None
use_default_system_instructions: bool | None = None
citations_enabled: bool | None = None
class NewLLMConfigRead(NewLLMConfigBase):
"""Schema for reading a NewLLMConfig (includes id and timestamps)."""
id: int
created_at: datetime
search_space_id: int
user_id: uuid.UUID
# Capability flag derived at the API boundary (no DB column). Default
# True matches the conservative-allow stance — a BYOK row that the
# route forgot to augment is not pre-judged. The streaming-task
# safety net is the only place a False actually blocks a request.
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map "
"(``litellm.supports_vision``) — there is no DB column. "
"Default True is the conservative-allow stance for unknown / "
"unmapped models."
),
)
model_config = ConfigDict(from_attributes=True)
class NewLLMConfigPublic(BaseModel):
"""
Public schema for NewLLMConfig that hides the API key.
Used when returning configs in list views or to users who shouldn't see keys.
"""
id: int
name: str
description: str | None = None
# Model Configuration (no api_key)
provider: LiteLLMProvider
custom_provider: str | None = None
model_name: str
api_base: str | None = None
litellm_params: dict[str, Any] | None = None
# Prompt Configuration
system_instructions: str
use_default_system_instructions: bool
citations_enabled: bool
created_at: datetime
search_space_id: int
user_id: uuid.UUID
# Capability flag derived at the API boundary (see NewLLMConfigRead).
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map. "
"Default True is the conservative-allow stance."
),
)
model_config = ConfigDict(from_attributes=True)
class DefaultSystemInstructionsResponse(BaseModel):
"""Response schema for getting default system instructions."""
default_system_instructions: str = Field(
..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template"
)
class GlobalNewLLMConfigRead(BaseModel):
"""
Schema for reading global LLM configs from YAML.
Global configs have negative IDs and no search_space_id.
API key is hidden for security.
ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing.
"""
id: int = Field(
...,
description="Config ID: 0 for Auto mode, negative for global configs",
)
name: str
description: str | None = None
# Model Configuration (no api_key)
provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode
custom_provider: str | None = None
model_name: str
api_base: str | None = None
litellm_params: dict[str, Any] | None = None
# Prompt Configuration
system_instructions: str = ""
use_default_system_instructions: bool = True
citations_enabled: bool = True
is_global: bool = True # Always true for global configs
is_auto_mode: bool = False # True only for Auto mode (ID 0)
billing_tier: str = "free"
is_premium: bool = False
anonymous_enabled: bool = False
seo_enabled: bool = False
seo_slug: str | None = None
seo_title: str | None = None
seo_description: str | None = None
quota_reserve_tokens: int | None = None
supports_image_input: bool = Field(
default=True,
description=(
"Whether the model accepts image inputs (multimodal vision). "
"Derived server-side: OpenRouter dynamic configs use "
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
"authoritative model map (``litellm.supports_vision``). The "
"new-chat selector hints with a 'No image' badge when this is "
"False and there are pending image attachments. The streaming "
"task fails fast only when LiteLLM *explicitly* marks a model "
"as text-only — unknown / unmapped models default-allow."
),
)
# =============================================================================
# LLM Preferences Schemas (for role assignments)
# =============================================================================
class LLMPreferencesRead(BaseModel):
"""Schema for reading LLM preferences (role assignments) for a search space."""
agent_llm_id: int | None = Field(
None, description="ID of the LLM config to use for agent/chat tasks"
)
image_generation_config_id: int | None = Field(
None, description="ID of the image generation config to use"
)
vision_llm_config_id: int | None = Field(
None,
description="ID of the vision LLM config to use for vision/screenshot analysis",
)
agent_llm: dict[str, Any] | None = Field(
None, description="Full config for agent LLM"
)
image_generation_config: dict[str, Any] | None = Field(
None, description="Full config for image generation"
)
vision_llm_config: dict[str, Any] | None = Field(
None, description="Full config for vision LLM"
)
model_config = ConfigDict(from_attributes=True)
class LLMPreferencesUpdate(BaseModel):
"""Schema for updating LLM preferences."""
agent_llm_id: int | None = Field(
None, description="ID of the LLM config to use for agent/chat tasks"
)
image_generation_config_id: int | None = Field(
None, description="ID of the image generation config to use"
)
vision_llm_config_id: int | None = Field(
None,
description="ID of the vision LLM config to use for vision/screenshot analysis",
)

View file

@ -1,116 +0,0 @@
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import VisionProvider
class VisionLLMConfigBase(BaseModel):
name: str = Field(..., max_length=100)
description: str | None = Field(None, max_length=500)
provider: VisionProvider = Field(...)
custom_provider: str | None = Field(None, max_length=100)
model_name: str = Field(..., max_length=100)
api_key: str = Field(...)
api_base: str | None = Field(None, max_length=500)
api_version: str | None = Field(None, max_length=50)
litellm_params: dict[str, Any] | None = Field(default=None)
class VisionLLMConfigCreate(VisionLLMConfigBase):
search_space_id: int = Field(...)
class VisionLLMConfigUpdate(BaseModel):
name: str | None = Field(None, max_length=100)
description: str | None = Field(None, max_length=500)
provider: VisionProvider | None = None
custom_provider: str | None = Field(None, max_length=100)
model_name: str | None = Field(None, max_length=100)
api_key: str | None = None
api_base: str | None = Field(None, max_length=500)
api_version: str | None = Field(None, max_length=50)
litellm_params: dict[str, Any] | None = None
class VisionLLMConfigRead(VisionLLMConfigBase):
id: int
created_at: datetime
search_space_id: int
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
class VisionLLMConfigPublic(BaseModel):
id: int
name: str
description: str | None = None
provider: VisionProvider
custom_provider: str | None = None
model_name: str
api_base: str | None = None
api_version: str | None = None
litellm_params: dict[str, Any] | None = None
created_at: datetime
search_space_id: int
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
class GlobalVisionLLMConfigRead(BaseModel):
"""Schema for reading global vision LLM configs from YAML.
The ``billing_tier`` field allows the frontend to show a Premium/Free
badge and (more importantly) tells the backend whether to debit the
user's premium credit pool when this config is used. ``"free"`` is
the default for backward compatibility admins must explicitly opt
a global config into ``"premium"``.
"""
id: int = Field(...)
name: str
description: str | None = None
provider: str
custom_provider: str | None = None
model_name: str
api_base: str | None = None
api_version: str | None = None
litellm_params: dict[str, Any] | None = None
is_global: bool = True
is_auto_mode: bool = False
billing_tier: str = Field(
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_tokens: int | None = Field(
default=None,
description=(
"Optional override for the per-call reservation in *tokens* — "
"converted to micro-USD via the model's input/output prices at "
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
),
)
input_cost_per_token: float | None = Field(
default=None,
description=(
"Optional input price in USD/token. Used by pricing_registration to "
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
),
)
output_cost_per_token: float | None = Field(
default=None,
description="Optional output price in USD/token. Pair with input_cost_per_token.",
)

View file

@ -1,13 +1,13 @@
"""Resolve and persist Auto (Fastest) model pins per chat thread.
"""Resolve and persist Auto model pins per chat thread.
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
resolve that virtual mode to one concrete global LLM config exactly once and
Auto is represented by ``chat_model_id == 0``. For chat threads we
resolve that virtual mode to one concrete global model exactly once and
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
subsequent turns are stable.
Single-writer invariant: this module is the only writer of
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
``model_connections_routes`` when a search space's ``chat_model_id`` changes).
Therefore a non-NULL value unambiguously means "this thread has an
Auto-resolved pin"; no separate source/policy column is needed.
"""
@ -21,26 +21,35 @@ import time
from dataclasses import dataclass
from uuid import UUID
import redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import NewChatThread
from app.db import Connection, Model, NewChatThread
from app.services.model_capabilities import has_capability
from app.services.quality_score import _QUALITY_TOP_K
from app.services.token_quota_service import TokenQuotaService
logger = logging.getLogger(__name__)
AUTO_FASTEST_ID = 0
AUTO_FASTEST_MODE = "auto_fastest"
AUTO_MODE_ID = 0
# Stable internal hash namespace for deterministic per-thread selection.
# Do not rename: changing this rebalances Auto's model choice for new pins.
AUTO_PIN_HASH_NAMESPACE = "auto_fastest"
_RUNTIME_COOLDOWN_SECONDS = 600
_HEALTHY_TTL_SECONDS = 45
_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX = "auto:cooldown:llm:"
_REDIS_TIMEOUT_SECONDS = 0.2
# In-memory runtime cooldown map for configs that recently hard-failed at
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
# the same unhealthy config from being reselected immediately during repair.
_runtime_cooldown_until: dict[int, float] = {}
_runtime_cooldown_lock = threading.Lock()
_runtime_cooldown_redis: redis.Redis | None = None
_runtime_cooldown_redis_lock = threading.Lock()
# Short-TTL "recently healthy" cache for configs that just passed a runtime
# preflight ping. Lets back-to-back turns on the same model skip the probe
@ -61,11 +70,15 @@ def _is_usable_global_config(cfg: dict) -> bool:
return bool(
cfg.get("id") is not None
and cfg.get("model_name")
and cfg.get("provider")
and (cfg.get("provider") or cfg.get("litellm_provider"))
and cfg.get("api_key")
)
def _has_capability(model: dict | Model, capability: str) -> bool:
return has_capability(model, capability)
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
now = time.time() if now_ts is None else now_ts
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
@ -79,6 +92,81 @@ def _is_runtime_cooled_down(config_id: int) -> bool:
return config_id in _runtime_cooldown_until
def _runtime_cooldown_redis_key(config_id: int) -> str:
return f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}{int(config_id)}"
def _get_runtime_cooldown_redis() -> redis.Redis:
global _runtime_cooldown_redis
if _runtime_cooldown_redis is None:
with _runtime_cooldown_redis_lock:
if _runtime_cooldown_redis is None:
_runtime_cooldown_redis = redis.from_url(
config.REDIS_APP_URL,
decode_responses=True,
socket_connect_timeout=_REDIS_TIMEOUT_SECONDS,
socket_timeout=_REDIS_TIMEOUT_SECONDS,
)
return _runtime_cooldown_redis
def _mark_shared_runtime_cooldown(
config_id: int,
*,
reason: str,
cooldown_seconds: int,
) -> None:
try:
_get_runtime_cooldown_redis().set(
_runtime_cooldown_redis_key(config_id),
reason,
ex=int(cooldown_seconds),
)
except Exception:
logger.warning(
"auto_pin_runtime_cooldown_redis_write_failed config_id=%s",
config_id,
exc_info=True,
)
def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]:
unique_ids = list(dict.fromkeys(int(cid) for cid in config_ids))
if not unique_ids:
return set()
try:
values = _get_runtime_cooldown_redis().mget(
[_runtime_cooldown_redis_key(cid) for cid in unique_ids]
)
except Exception:
logger.warning(
"auto_pin_runtime_cooldown_redis_read_failed count=%s",
len(unique_ids),
exc_info=True,
)
return set()
return {
cid for cid, value in zip(unique_ids, values, strict=False) if value is not None
}
def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None:
try:
client = _get_runtime_cooldown_redis()
if config_id is not None:
client.delete(_runtime_cooldown_redis_key(config_id))
return
keys = list(client.scan_iter(f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}*"))
if keys:
client.delete(*keys)
except Exception:
logger.warning(
"auto_pin_runtime_cooldown_redis_clear_failed config_id=%s",
config_id,
exc_info=True,
)
def mark_runtime_cooldown(
config_id: int,
*,
@ -97,6 +185,11 @@ def mark_runtime_cooldown(
with _runtime_cooldown_lock:
_runtime_cooldown_until[int(config_id)] = until
_prune_runtime_cooldowns()
_mark_shared_runtime_cooldown(
int(config_id),
reason=reason,
cooldown_seconds=int(cooldown_seconds),
)
# A cooled cfg can never be "recently healthy"; drop any stale credit so
# the next turn that resolves to it (after cooldown) re-runs preflight.
clear_healthy(int(config_id))
@ -113,8 +206,9 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None:
with _runtime_cooldown_lock:
if config_id is None:
_runtime_cooldown_until.clear()
return
_runtime_cooldown_until.pop(int(config_id), None)
else:
_runtime_cooldown_until.pop(int(config_id), None)
_clear_shared_runtime_cooldown(config_id)
def _prune_healthy(now_ts: float | None = None) -> None:
@ -186,15 +280,20 @@ def _cfg_supports_image_input(cfg: dict) -> bool:
else None
)
return derive_supports_image_input(
provider=cfg.get("provider"),
provider=cfg.get("provider") or cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
"""Return Auto-eligible global cfgs.
def _global_candidates(
*,
capability: str = "chat",
requires_image_input: bool = False,
shared_cooled_down_ids: set[int] | None = None,
) -> list[dict]:
"""Return Auto-eligible global virtual models.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
@ -205,30 +304,167 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
filters out configs whose ``supports_image_input`` resolves to False
so a text-only deployment can't be pinned for an image request.
"""
candidates = [
cfg
connection_by_id = {
int(conn.get("id")): conn
for conn in config.GLOBAL_CONNECTIONS
if conn.get("id") is not None
}
config_by_model_name = {
cfg.get("model_name"): cfg
for cfg in config.GLOBAL_LLM_CONFIGS
if _is_usable_global_config(cfg)
and not cfg.get("health_gated")
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
and (not requires_image_input or _cfg_supports_image_input(cfg))
]
}
candidates: list[dict] = []
shared_cooled_down_ids = shared_cooled_down_ids or set()
for model in config.GLOBAL_MODELS:
model_id = int(model.get("id", 0))
if (
model_id >= 0
or _is_runtime_cooled_down(model_id)
or model_id in shared_cooled_down_ids
):
continue
if not _has_capability(model, capability):
continue
cfg = config_by_model_name.get(model.get("model_id")) or {}
if cfg.get("health_gated"):
continue
if requires_image_input and not _has_capability(model, "vision"):
continue
if requires_image_input and cfg and not _cfg_supports_image_input(cfg):
continue
connection = connection_by_id.get(int(model.get("connection_id", 0)))
if not connection:
continue
catalog = model.get("catalog") or {}
candidates.append(
{
"id": model_id,
"model_id": model.get("model_id"),
"source": "global",
"connection": connection,
"supports_chat": model.get("supports_chat"),
"supports_image_input": model.get("supports_image_input"),
"supports_tools": model.get("supports_tools"),
"supports_image_generation": model.get("supports_image_generation"),
"capabilities_override": model.get("capabilities_override") or {},
"billing_tier": model.get("billing_tier", "free"),
"provider": connection.get("provider"),
"model_name": model.get("model_id"),
"auto_pin_tier": catalog.get("auto_pin_tier")
or cfg.get("auto_pin_tier")
or "A",
"quality_score": catalog.get("quality_score")
or cfg.get("quality_score")
or cfg.get("quality_score_static")
or 50,
}
)
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
async def _db_candidates(
session: AsyncSession,
*,
search_space_id: int,
user_id: str | UUID | None,
capability: str,
requires_image_input: bool = False,
) -> list[dict]:
parsed_user_id = _to_uuid(user_id)
stmt = (
select(Model)
.options(selectinload(Model.connection))
.join(Connection, Model.connection_id == Connection.id)
.where(Model.enabled.is_(True), Connection.enabled.is_(True))
)
result = await session.execute(stmt)
models = result.scalars().all()
shared_cooled_down_ids = _shared_runtime_cooled_down_ids(
[int(model.id) for model in models]
)
candidates: list[dict] = []
for model in models:
conn = model.connection
if not conn:
continue
if conn.search_space_id is not None and conn.search_space_id != search_space_id:
continue
if (
conn.user_id is not None
and parsed_user_id is not None
and conn.user_id != parsed_user_id
):
continue
if conn.user_id is not None and parsed_user_id is None:
continue
if not _has_capability(model, capability):
continue
if requires_image_input and not _has_capability(model, "vision"):
continue
model_id = int(model.id)
if _is_runtime_cooled_down(model_id) or model_id in shared_cooled_down_ids:
continue
catalog = model.catalog or {}
candidates.append(
{
"id": model_id,
"model_id": model.model_id,
"source": "db",
"connection": conn,
"supports_chat": model.supports_chat,
"supports_image_input": model.supports_image_input,
"supports_tools": model.supports_tools,
"supports_image_generation": model.supports_image_generation,
"capabilities_override": model.capabilities_override or {},
"billing_tier": "byok",
"provider": conn.provider,
"model_name": model.model_id,
"auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK",
"quality_score": catalog.get("quality_score") or 75,
}
)
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
async def auto_model_candidates(
session: AsyncSession,
*,
search_space_id: int,
user_id: str | UUID | None,
capability: str,
requires_image_input: bool = False,
exclude_model_ids: set[int] | None = None,
) -> list[dict]:
excluded_ids = {int(mid) for mid in (exclude_model_ids or set())}
global_ids = [
int(model.get("id", 0))
for model in config.GLOBAL_MODELS
if int(model.get("id", 0)) < 0
]
shared_global_cooled_down_ids = _shared_runtime_cooled_down_ids(global_ids)
db_candidates = await _db_candidates(
session,
search_space_id=search_space_id,
user_id=user_id,
capability=capability,
requires_image_input=requires_image_input,
)
candidates = [
*_global_candidates(
capability=capability,
requires_image_input=requires_image_input,
shared_cooled_down_ids=shared_global_cooled_down_ids,
),
*db_candidates,
]
return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids]
def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower()
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
"""Return True for the operator-preferred premium Auto model."""
return (
_tier_of(cfg) == "premium"
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
)
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
"""Pick a config with quality-first ranking + deterministic spread.
@ -246,11 +482,16 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
pool = tier_a if tier_a else eligible
pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
top_k = pool[:_QUALITY_TOP_K]
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest()
idx = int.from_bytes(digest[:8], "big") % len(top_k)
return top_k[idx], len(top_k)
def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict:
selected, _ = _select_pin(candidates, seed_id)
return selected
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
if user_id is None:
return None
@ -283,7 +524,7 @@ async def resolve_or_get_pinned_llm_config_id(
exclude_config_ids: set[int] | None = None,
requires_image_input: bool = False,
) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
"""Resolve Auto to one concrete config id and persist the pin.
For non-auto selections, this function clears any existing pin and returns
the selected id as-is.
@ -315,7 +556,7 @@ async def resolve_or_get_pinned_llm_config_id(
)
# Explicit model selected: clear any stale pin.
if selected_llm_config_id != AUTO_FASTEST_ID:
if selected_llm_config_id != AUTO_MODE_ID:
if thread.pinned_llm_config_id is not None:
thread.pinned_llm_config_id = None
await session.commit()
@ -326,20 +567,21 @@ async def resolve_or_get_pinned_llm_config_id(
)
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
c
for c in _global_candidates(requires_image_input=requires_image_input)
if int(c.get("id", 0)) not in excluded_ids
]
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=user_id,
capability="chat",
requires_image_input=requires_image_input,
exclude_model_ids=excluded_ids,
)
if not candidates:
if requires_image_input:
# Distinguish the "no vision-capable cfg" case from generic
# "no usable cfg" so the streaming task can map this to the
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
raise ValueError(
"No vision-capable global LLM configs are available for Auto mode"
)
raise ValueError("No usable global LLM configs are available for Auto mode")
raise ValueError("No vision-capable LLM models are available for Auto mode")
raise ValueError("No usable LLM models are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent
@ -379,24 +621,13 @@ async def resolve_or_get_pinned_llm_config_id(
# log that explicitly so operators can correlate the re-pin with
# the user's image attachment instead of suspecting a cooldown.
if requires_image_input:
try:
pinned_global = next(
c
for c in config.GLOBAL_LLM_CONFIGS
if int(c.get("id", 0)) == int(pinned_id)
)
except StopIteration:
pinned_global = None
if pinned_global is not None and not _cfg_supports_image_input(
pinned_global
):
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id,
@ -407,12 +638,10 @@ async def resolve_or_get_pinned_llm_config_id(
premium_eligible = (
False if force_repin_free else await _is_premium_eligible(session, user_id)
)
byok_candidates = [c for c in candidates if _tier_of(c) == "byok"]
if premium_eligible:
premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
preferred_premium = [
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
]
eligible = preferred_premium or premium_candidates
eligible = premium_candidates or byok_candidates
else:
eligible = [c for c in candidates if _tier_of(c) != "premium"]

View file

@ -445,15 +445,15 @@ async def _resolve_agent_billing_for_search_space(
thread_id: int | None = None,
) -> tuple[UUID, str, str]:
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
agent LLM.
chat model.
Used by Celery tasks (podcast generation, video presentation) to bill the
search-space owner's premium credit pool when the agent LLM is premium.
search-space owner's premium credit pool when the chat model is premium.
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
Resolution rules mirror the chat model role resolver:
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
- Search space not found / no ``chat_model_id``: raise ``ValueError``.
- **Auto mode** (``id == AUTO_MODE_ID == 0``):
* ``thread_id`` is set: delegate to
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
recurse into the resolved id. Reuses chat's existing pin if present
@ -469,9 +469,8 @@ async def _resolve_agent_billing_for_search_space(
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
``base_model = litellm_params.get("base_model") or model_name``
NOT provider-prefixed, matching chat's cost-map lookup convention.
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
``base_model`` from ``litellm_params`` or ``model_name``.
- **Positive id** (user BYOK ``Model``): always free; ``base_model`` from
the model catalog override or the upstream ``model_id``.
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
``llm_router_service`` are imported lazily inside the function body to
@ -480,8 +479,9 @@ async def _resolve_agent_billing_for_search_space(
``billable_calls.py``'s module load path.
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.db import NewLLMConfig, SearchSpace
from app.db import Model, SearchSpace
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
@ -490,20 +490,20 @@ async def _resolve_agent_billing_for_search_space(
if search_space is None:
raise ValueError(f"Search space {search_space_id} not found")
agent_llm_id = search_space.agent_llm_id
if agent_llm_id is None:
chat_model_id = search_space.chat_model_id
if chat_model_id is None:
raise ValueError(
f"Search space {search_space_id} has no agent_llm_id configured"
f"Search space {search_space_id} has no chat_model_id configured"
)
owner_user_id: UUID = search_space.user_id
from app.services.auto_model_pin_service import (
AUTO_FASTEST_ID,
AUTO_MODE_ID,
resolve_or_get_pinned_llm_config_id,
)
if agent_llm_id == AUTO_FASTEST_ID:
if chat_model_id == AUTO_MODE_ID:
if thread_id is None:
return owner_user_id, "free", "auto"
try:
@ -512,7 +512,7 @@ async def _resolve_agent_billing_for_search_space(
thread_id=thread_id,
search_space_id=search_space_id,
user_id=str(owner_user_id),
selected_llm_config_id=AUTO_FASTEST_ID,
selected_llm_config_id=AUTO_MODE_ID,
)
except ValueError:
logger.warning(
@ -523,28 +523,35 @@ async def _resolve_agent_billing_for_search_space(
exc_info=True,
)
return owner_user_id, "free", "auto"
agent_llm_id = resolution.resolved_llm_config_id
chat_model_id = resolution.resolved_llm_config_id
if agent_llm_id < 0:
if chat_model_id < 0:
from app.services.llm_service import get_global_llm_config
cfg = get_global_llm_config(agent_llm_id) or {}
cfg = get_global_llm_config(chat_model_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
litellm_params = cfg.get("litellm_params") or {}
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
return owner_user_id, billing_tier, base_model
nlc_result = await session.execute(
select(NewLLMConfig).where(
NewLLMConfig.id == agent_llm_id,
NewLLMConfig.search_space_id == search_space_id,
)
model_result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == chat_model_id, Model.enabled.is_(True))
)
nlc = nlc_result.scalars().first()
model = model_result.scalars().first()
base_model = ""
if nlc is not None:
litellm_params = nlc.litellm_params or {}
base_model = litellm_params.get("base_model") or nlc.model_name or ""
if (
model is not None
and model.connection is not None
and model.connection.enabled
and (
model.connection.search_space_id in (None, search_space_id)
and model.connection.user_id in (None, owner_user_id)
)
):
catalog = model.catalog or {}
base_model = catalog.get("base_model") or model.model_id or ""
return owner_user_id, "free", base_model

View file

@ -0,0 +1,128 @@
"""Materialize server-owned GLOBAL YAML configs as virtual connections/models."""
from __future__ import annotations
from typing import Any
from app.services.model_resolver import native_connection_from_config
def _base_model(config: dict[str, Any]) -> str | None:
litellm_params = config.get("litellm_params") or {}
if isinstance(litellm_params, dict):
return litellm_params.get("base_model")
return None
def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]:
# Deliberately includes api_key because two operator-owned credentials for
# the same provider/base can have different quota/rate limits upstream.
return (
conn.get("provider"),
conn.get("base_url"),
conn.get("api_key"),
_freeze(conn.get("extra") or {}),
)
def _freeze(value: Any) -> Any:
if isinstance(value, dict):
return tuple(sorted((key, _freeze(val)) for key, val in value.items()))
if isinstance(value, list):
return tuple(_freeze(item) for item in value)
return value
def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]:
return {
"billing_tier": config.get("billing_tier", "free"),
"quota_reserve_tokens": config.get("quota_reserve_tokens"),
"rpm": config.get("rpm"),
"tpm": config.get("tpm"),
"anonymous_enabled": config.get("anonymous_enabled", False),
"seo_enabled": config.get("seo_enabled", False),
"seo_slug": config.get("seo_slug"),
"input_cost_per_token": (config.get("litellm_params") or {}).get(
"input_cost_per_token"
)
if isinstance(config.get("litellm_params"), dict)
else None,
"output_cost_per_token": (config.get("litellm_params") or {}).get(
"output_cost_per_token"
)
if isinstance(config.get("litellm_params"), dict)
else None,
"is_planner": config.get("is_planner", False),
"base_model": _base_model(config),
"router_pool_eligible": config.get("router_pool_eligible", True),
}
def materialize_global_model_catalog(
*,
chat_configs: list[dict[str, Any]],
image_configs: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
connections: list[dict[str, Any]] = []
models: list[dict[str, Any]] = []
connection_id_by_key: dict[tuple[Any, ...], int] = {}
next_connection_id = -1
def add_config(config: dict[str, Any], role: str) -> None:
nonlocal next_connection_id
if not config.get("id") or not config.get("model_name"):
return
conn = native_connection_from_config(config)
conn["scope"] = "GLOBAL"
conn["enabled"] = True
key = _connection_key(conn)
connection_id = connection_id_by_key.get(key)
if connection_id is None:
connection_id = next_connection_id
next_connection_id -= 1
connection_id_by_key[key] = connection_id
connections.append(
{
"id": connection_id,
**conn,
}
)
model_id = int(config["id"])
models.append(
{
"id": model_id,
"connection_id": connection_id,
"model_id": config["model_name"],
"display_name": config.get("name") or config["model_name"],
"source": "MANUAL",
"supports_chat": role == "chat",
"max_input_tokens": config.get("max_input_tokens"),
"supports_image_input": (
role == "chat" and bool(config.get("supports_image_input"))
),
"supports_tools": bool(config.get("supports_tools", False)),
"supports_image_generation": role == "image_gen",
"capabilities_override": {},
"enabled": True,
"billing_tier": config.get("billing_tier", "free"),
"catalog": _catalog_metadata(config),
"role": role,
}
)
for cfg in chat_configs:
if cfg.get("is_auto_mode"):
continue
add_config(cfg, "chat")
for cfg in image_configs:
if cfg.get("is_auto_mode"):
continue
add_config(cfg, "image_gen")
# Each virtual connection is server-only. Callers that serialize these
# must strip api_key before returning data to clients.
return connections, models
__all__ = ["materialize_global_model_catalog"]

View file

@ -20,28 +20,13 @@ from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
from app.services.provider_api_base import resolve_api_base
from app.services.model_resolver import native_connection_from_config, to_litellm
logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing
IMAGE_GEN_AUTO_MODE_ID = 0
# Provider mapping for LiteLLM model string construction.
# Only includes providers that support image generation.
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
IMAGE_GEN_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini", # Google AI Studio
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock", # AWS Bedrock
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
class ImageGenRouterService:
"""
@ -153,38 +138,11 @@ class ImageGenRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
# Build model string
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
provider_prefix = config["custom_provider"]
else:
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params: dict[str, Any] = {
"model": model_string,
"api_key": config.get("api_key"),
}
# Resolve ``api_base`` so deployments don't silently inherit
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
# the wrong provider (see ``provider_api_base`` docstring).
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(config),
config["model_name"],
)
if api_base:
litellm_params["api_base"] = api_base
# Add api_version (required for Azure)
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]
# Add any additional litellm parameters
if config.get("litellm_params"):
litellm_params.update(config["litellm_params"])
litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs}
# All configs use same alias "auto" for unified routing
deployment: dict[str, Any] = {

View file

@ -0,0 +1,257 @@
"""Normalize provider/LLM exceptions into low-cardinality product categories."""
from __future__ import annotations
import json
from dataclasses import dataclass
from enum import StrEnum
from typing import Any
class LLMErrorCategory(StrEnum):
RATE_LIMITED = "rate_limited"
TIMEOUT = "timeout"
PROVIDER_UNAVAILABLE = "provider_unavailable"
BAD_GATEWAY = "bad_gateway"
CONNECTION_FAILED = "connection_failed"
AUTH_FAILED = "auth_failed"
PERMISSION_DENIED = "permission_denied"
MODEL_NOT_FOUND = "model_not_found"
BAD_REQUEST = "bad_request"
CONTEXT_LIMIT = "context_limit"
RESPONSE_INVALID = "response_invalid"
SERVER_ERROR = "server_error"
UNKNOWN = "unknown"
@dataclass(frozen=True)
class LLMErrorAdaptation:
category: LLMErrorCategory
retryable: bool
user_message: str
provider_status_code: int | None = None
provider_error_type: str | None = None
_CATEGORY_MESSAGES: dict[LLMErrorCategory, str] = {
LLMErrorCategory.RATE_LIMITED: "LLM rate limit exceeded. Will retry on next sync.",
LLMErrorCategory.TIMEOUT: "LLM request timed out. Will retry on next sync.",
LLMErrorCategory.PROVIDER_UNAVAILABLE: "LLM service temporarily unavailable. Will retry on next sync.",
LLMErrorCategory.BAD_GATEWAY: "LLM gateway error. Will retry on next sync.",
LLMErrorCategory.CONNECTION_FAILED: "Could not reach the LLM service. Check network connectivity.",
LLMErrorCategory.AUTH_FAILED: "LLM authentication failed. Check your API key.",
LLMErrorCategory.PERMISSION_DENIED: "LLM request denied. Check your account permissions.",
LLMErrorCategory.MODEL_NOT_FOUND: "Model not found. Check your model configuration.",
LLMErrorCategory.BAD_REQUEST: "LLM rejected the request. Document content may be invalid.",
LLMErrorCategory.CONTEXT_LIMIT: "Document exceeds the LLM context window even after optimization.",
LLMErrorCategory.RESPONSE_INVALID: "LLM returned an invalid response.",
LLMErrorCategory.SERVER_ERROR: "LLM internal server error. Will retry on next sync.",
LLMErrorCategory.UNKNOWN: "Something went wrong when calling the LLM.",
}
_RETRYABLE_CATEGORIES = {
LLMErrorCategory.RATE_LIMITED,
LLMErrorCategory.TIMEOUT,
LLMErrorCategory.PROVIDER_UNAVAILABLE,
LLMErrorCategory.BAD_GATEWAY,
LLMErrorCategory.CONNECTION_FAILED,
LLMErrorCategory.SERVER_ERROR,
}
_CLASS_NAME_MAP: tuple[tuple[LLMErrorCategory, tuple[str, ...]], ...] = (
(
LLMErrorCategory.RATE_LIMITED,
("RateLimitError", "TooManyRequests", "TooManyRequestsError"),
),
(LLMErrorCategory.TIMEOUT, ("Timeout", "APITimeoutError", "TimeoutException")),
(
LLMErrorCategory.PROVIDER_UNAVAILABLE,
("ServiceUnavailableError", "ServiceUnavailable"),
),
(
LLMErrorCategory.BAD_GATEWAY,
("BadGatewayError", "GatewayTimeoutError"),
),
(
LLMErrorCategory.CONNECTION_FAILED,
("APIConnectionError", "ConnectError", "ConnectTimeout", "ReadTimeout"),
),
(
LLMErrorCategory.AUTH_FAILED,
("AuthenticationError", "InvalidApiKey", "InvalidAPIKey", "InvalidApiKeyError"),
),
(LLMErrorCategory.PERMISSION_DENIED, ("PermissionDeniedError", "ForbiddenError")),
(LLMErrorCategory.MODEL_NOT_FOUND, ("NotFoundError", "ModelNotFoundError")),
(
LLMErrorCategory.CONTEXT_LIMIT,
("ContextWindowExceeded", "ContextOverflow", "ContextLimit"),
),
(
LLMErrorCategory.RESPONSE_INVALID,
("APIResponseValidationError", "ResponseValidationError"),
),
(
LLMErrorCategory.BAD_REQUEST,
("BadRequestError", "InvalidRequestError", "UnprocessableEntityError"),
),
(LLMErrorCategory.SERVER_ERROR, ("InternalServerError",)),
)
def _parse_error_payload(message: str) -> dict[str, Any] | None:
candidates = [message]
first_brace_idx = message.find("{")
if first_brace_idx >= 0:
candidates.append(message[first_brace_idx:])
for candidate in candidates:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
return parsed
except Exception:
continue
return None
def _class_names(exc: BaseException) -> tuple[str, ...]:
return tuple(cls.__name__ for cls in type(exc).__mro__)
def _category_from_class_name(exc: BaseException) -> LLMErrorCategory | None:
names = _class_names(exc)
for category, hints in _CLASS_NAME_MAP:
if any(any(hint in name for hint in hints) for name in names):
return category
return None
def _extract_provider_status_code(parsed: dict[str, Any] | None) -> int | None:
if not isinstance(parsed, dict):
return None
candidates: list[Any] = [parsed.get("code"), parsed.get("status")]
nested = parsed.get("error")
if isinstance(nested, dict):
candidates.extend([nested.get("code"), nested.get("status")])
for value in candidates:
try:
if value is None:
continue
return int(value)
except Exception:
continue
return None
def _extract_provider_error_type(parsed: dict[str, Any] | None) -> str | None:
if not isinstance(parsed, dict):
return None
candidates: list[Any] = [parsed.get("type")]
nested = parsed.get("error")
if isinstance(nested, dict):
candidates.append(nested.get("type"))
for value in candidates:
if isinstance(value, str) and value:
return value
return None
def _category_from_provider_payload(
status_code: int | None,
provider_error_type: str | None,
) -> LLMErrorCategory | None:
if status_code == 429:
return LLMErrorCategory.RATE_LIMITED
if status_code == 401:
return LLMErrorCategory.AUTH_FAILED
if status_code == 403:
return LLMErrorCategory.PERMISSION_DENIED
if status_code == 404:
return LLMErrorCategory.MODEL_NOT_FOUND
if status_code in (400, 422):
return LLMErrorCategory.BAD_REQUEST
if status_code in (502, 504):
return LLMErrorCategory.BAD_GATEWAY
if status_code == 503:
return LLMErrorCategory.PROVIDER_UNAVAILABLE
if status_code is not None and status_code >= 500:
return LLMErrorCategory.SERVER_ERROR
normalized_type = (provider_error_type or "").lower()
if normalized_type == "rate_limit_error":
return LLMErrorCategory.RATE_LIMITED
if normalized_type in {
"authentication_error",
"invalid_api_key",
"invalid_api_key_error",
}:
return LLMErrorCategory.AUTH_FAILED
if normalized_type in {"permission_denied", "forbidden"}:
return LLMErrorCategory.PERMISSION_DENIED
if normalized_type in {"not_found_error", "model_not_found"}:
return LLMErrorCategory.MODEL_NOT_FOUND
if normalized_type in {"context_length_exceeded", "context_window_exceeded"}:
return LLMErrorCategory.CONTEXT_LIMIT
return None
def _category_from_message(raw: str) -> LLMErrorCategory | None:
lowered = raw.lower()
if any(
hint in lowered
for hint in ("rate limit", "rate-limited", "temporarily rate-limited")
):
return LLMErrorCategory.RATE_LIMITED
if any(
hint in lowered
for hint in (
"invalid api key",
"invalid_api_key",
"authentication",
"unauthorized",
"user not found",
"api key is expired",
"expired api key",
)
):
return LLMErrorCategory.AUTH_FAILED
if "forbidden" in lowered or "permission denied" in lowered:
return LLMErrorCategory.PERMISSION_DENIED
if "model not found" in lowered:
return LLMErrorCategory.MODEL_NOT_FOUND
if any(
hint in lowered
for hint in (
"context length",
"context window",
"maximum context",
"too many tokens",
)
):
return LLMErrorCategory.CONTEXT_LIMIT
return None
def adapt_llm_exception(exc: BaseException) -> LLMErrorAdaptation:
raw = str(exc)
parsed = _parse_error_payload(raw)
status_code = _extract_provider_status_code(parsed)
provider_error_type = _extract_provider_error_type(parsed)
category = (
_category_from_provider_payload(status_code, provider_error_type)
or _category_from_message(raw)
or _category_from_class_name(exc)
or LLMErrorCategory.UNKNOWN
)
return LLMErrorAdaptation(
category=category,
retryable=category in _RETRYABLE_CATEGORIES,
user_message=_CATEGORY_MESSAGES[category],
provider_status_code=status_code,
provider_error_type=provider_error_type,
)
def llm_error_message(exc: BaseException) -> str:
return adapt_llm_exception(exc).user_message

View file

@ -30,6 +30,7 @@ from litellm.exceptions import (
)
from pydantic import Field
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.utils.perf import get_perf_logger
litellm.json_logs = False
@ -96,53 +97,6 @@ def _sanitize_content(content: Any) -> Any:
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
# Provider mapping for LiteLLM model string construction
PROVIDER_MAP = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock", # Legacy support
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
# call sites can share the exact same defense (OpenRouter / Groq / etc.
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402
resolve_api_base,
)
class LLMRouterService:
"""
@ -420,38 +374,11 @@ class LLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
# Build model string
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params = {
"model": model_string,
"api_key": config.get("api_key"),
}
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
# requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(config),
config["model_name"],
)
if api_base:
litellm_params["api_base"] = api_base
# Add any additional litellm parameters
if config.get("litellm_params"):
litellm_params.update(config["litellm_params"])
litellm_params = {"model": model_string, **resolved_kwargs}
# Extract rate limits if provided
deployment = {

View file

@ -6,17 +6,21 @@ from langchain_core.messages import HumanMessage
from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import NewLLMConfig, SearchSpace
from app.db import Model, SearchSpace
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.model_capabilities import has_capability
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters
@ -66,6 +70,29 @@ def _is_interactive_auth_provider(
return False
def _legacy_config_connection(
*,
provider: str,
model_name: str,
api_key: str | None,
api_base: str | None,
custom_provider: str | None = None,
litellm_params: dict | None = None,
api_version: str | None = None,
) -> tuple[str, dict]:
cfg = {
"provider": provider.lower(),
"model_name": model_name,
"api_key": api_key,
"api_base": api_base,
"custom_provider": custom_provider,
"api_version": api_version,
"litellm_params": litellm_params or {},
}
conn = native_connection_from_config(cfg)
return to_litellm(conn, model_name)
class LLMRole:
AGENT = "agent" # For agent/chat operations
@ -73,26 +100,16 @@ class LLMRole:
def get_global_llm_config(llm_config_id: int) -> dict | None:
"""
Get a global LLM configuration by ID.
Global configs have negative IDs. ID 0 is reserved for Auto mode.
Global configs have negative IDs. Auto mode (ID 0) is resolved through the
model-candidate pipeline, not this legacy config lookup.
Args:
llm_config_id: The ID of the global config (should be negative or 0 for Auto)
llm_config_id: The ID of the global config (must be negative)
Returns:
dict: Global config dictionary or None if not found
"""
# Auto mode (ID 0) is handled separately via the router
if llm_config_id == AUTO_MODE_ID:
return {
"id": AUTO_MODE_ID,
"name": "Auto (Fastest)",
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
"provider": "AUTO",
"model_name": "auto",
"is_auto_mode": True,
}
if llm_config_id > 0:
if llm_config_id >= 0:
return None
for cfg in config.GLOBAL_LLM_CONFIGS:
@ -102,6 +119,55 @@ def get_global_llm_config(llm_config_id: int) -> dict | None:
return None
def get_global_model(model_id: int) -> dict | None:
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
def get_global_connection(connection_id: int) -> dict | None:
return next(
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
None,
)
def _has_capability(model: dict | Model, capability: str) -> bool:
return has_capability(model, capability)
def _chat_litellm_from_resolved(
*,
conn: dict | object,
model_id: str,
disable_streaming: bool = False,
) -> tuple[str, dict]:
model_string, resolved_kwargs = to_litellm(conn, model_id)
litellm_kwargs = {"model": model_string, **resolved_kwargs}
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
return model_string, litellm_kwargs
async def _get_db_model(
session: AsyncSession,
model_id: int,
search_space: SearchSpace,
) -> Model | None:
result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == model_id, Model.enabled.is_(True))
)
model = result.scalars().first()
if not model or not model.connection or not model.connection.enabled:
return None
conn = model.connection
if conn.search_space_id and conn.search_space_id != search_space.id:
return None
if conn.user_id and conn.user_id != search_space.user_id:
return None
return model
async def validate_llm_config(
provider: str,
model_name: str,
@ -146,62 +212,15 @@ async def validate_llm_config(
return False, msg
try:
# Build the model string for litellm
if custom_provider:
model_string = f"{custom_provider}/{model_name}"
else:
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility)
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
# Chinese LLM providers
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai", # GLM needs special handling
"MINIMAX": "openai",
"GITHUB_MODELS": "github",
}
provider_prefix = provider_map.get(provider, provider.lower())
model_string = f"{provider_prefix}/{model_name}"
# Create ChatLiteLLM instance
litellm_kwargs = {
"model": model_string,
"api_key": api_key,
"timeout": 30, # Set a timeout for validation
}
# Add optional parameters
if api_base:
litellm_kwargs["api_base"] = api_base
# Add any additional litellm parameters
if litellm_params:
litellm_kwargs.update(litellm_params)
model_string, resolved_kwargs = _legacy_config_connection(
provider=provider,
model_name=model_name,
api_key=api_key,
api_base=api_base,
custom_provider=custom_provider,
litellm_params=litellm_params,
)
litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30}
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -283,9 +302,9 @@ async def get_search_space_llm_instance(
logger.error(f"Search space {search_space_id} not found")
return None
# Get the appropriate LLM config ID based on role
# Get the appropriate model binding ID based on role
if role == LLMRole.AGENT:
llm_config_id = search_space.agent_llm_id
llm_config_id = search_space.chat_model_id
else:
logger.error(f"Invalid LLM role: {role}")
return None
@ -294,88 +313,42 @@ async def get_search_space_llm_instance(
logger.error(f"No {role} LLM configured for search space {search_space_id}")
return None
# Check for Auto mode (ID 0) - use router for load balancing
# Auto mode resolves to one concrete global or BYOK model from the
# unified model-connections catalog.
if is_auto_mode(llm_config_id):
if not LLMRouterService.is_initialized():
logger.error(
"Auto mode requested but LLM Router not initialized. "
"Ensure global_llm_config.yaml exists with valid configs."
)
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=search_space.user_id,
capability="chat",
)
if not candidates:
logger.error("No chat-capable models available for Auto mode")
return None
llm_config_id = int(
choose_auto_model_candidate(candidates, search_space_id)["id"]
)
try:
logger.debug(
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
)
return get_auto_mode_llm(streaming=not disable_streaming)
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None
# Check if this is a global config (negative ID)
# Check if this is a global virtual model (negative ID)
if llm_config_id < 0:
global_config = get_global_llm_config(llm_config_id)
if not global_config:
logger.error(f"Global LLM config {llm_config_id} not found")
global_model = get_global_model(llm_config_id)
if not global_model or not _has_capability(global_model, "chat"):
logger.error(f"Global chat model {llm_config_id} not found")
return None
global_connection = get_global_connection(global_model["connection_id"])
if not global_connection:
logger.error(
"Global connection %s not found for model %s",
global_model["connection_id"],
llm_config_id,
)
return None
# Build model string for global config
if global_config.get("custom_provider"):
model_string = (
f"{global_config['custom_provider']}/{global_config['model_name']}"
)
else:
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"MINIMAX": "openai",
}
provider_prefix = provider_map.get(
global_config["provider"], global_config["provider"].lower()
)
model_string = f"{provider_prefix}/{global_config['model_name']}"
# Create ChatLiteLLM instance from global config
litellm_kwargs = {
"model": model_string,
"api_key": global_config["api_key"],
}
if global_config.get("api_base"):
litellm_kwargs["api_base"] = global_config["api_base"]
if global_config.get("litellm_params"):
litellm_kwargs.update(global_config["litellm_params"])
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=global_connection,
model_id=global_model["model_id"],
disable_streaming=disable_streaming,
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -383,80 +356,18 @@ async def get_search_space_llm_instance(
return SanitizedChatLiteLLM(**litellm_kwargs)
# Get the LLM configuration from database (NewLLMConfig)
result = await session.execute(
select(NewLLMConfig).where(
NewLLMConfig.id == llm_config_id,
NewLLMConfig.search_space_id == search_space_id,
)
)
llm_config = result.scalars().first()
if not llm_config:
model = await _get_db_model(session, llm_config_id, search_space)
if not model or not _has_capability(model, "chat"):
logger.error(
f"LLM config {llm_config_id} not found in search space {search_space_id}"
f"Chat model {llm_config_id} not found in search space {search_space_id}"
)
return None
# Build the model string for litellm
if llm_config.custom_provider:
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
else:
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"MINIMAX": "openai",
"GITHUB_MODELS": "github",
}
provider_prefix = provider_map.get(
llm_config.provider.value, llm_config.provider.value.lower()
)
model_string = f"{provider_prefix}/{llm_config.model_name}"
# Create ChatLiteLLM instance
litellm_kwargs = {
"model": model_string,
"api_key": llm_config.api_key,
}
# Add optional parameters
if llm_config.api_base:
litellm_kwargs["api_base"] = llm_config.api_base
# Add any additional litellm parameters
if llm_config.litellm_params:
litellm_kwargs.update(llm_config.litellm_params)
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=model.connection,
model_id=model.model_id,
disable_streaming=disable_streaming,
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -474,7 +385,7 @@ async def get_search_space_llm_instance(
async def get_agent_llm(
session: AsyncSession, search_space_id: int, disable_streaming: bool = False
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Get the search space's agent LLM instance for chat operations."""
"""Get the search space's chat model instance."""
return await get_search_space_llm_instance(
session,
search_space_id,
@ -488,24 +399,17 @@ async def get_vision_llm(
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Get the search space's vision LLM instance for screenshot analysis.
Resolves from the dedicated VisionLLMConfig system:
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
Resolves from the new connection/model role bindings:
- Auto mode (ID 0): unified global/BYOK model candidate selection
- Global (negative ID): virtual GLOBAL models from YAML
- DB (positive ID): Model + Connection tables
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
so each ``ainvoke`` debits the search-space owner's premium credit
pool. User-owned BYOK configs and free global configs are returned
unwrapped they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
get_global_vision_llm_config,
is_vision_auto_mode,
)
try:
result = await session.execute(
@ -516,64 +420,78 @@ async def get_vision_llm(
logger.error(f"Search space {search_space_id} not found")
return None
config_id = search_space.vision_llm_config_id
owner_user_id = search_space.user_id
# Prefer the selected chat model when it is vision-capable.
chat_model_id = search_space.chat_model_id
if chat_model_id and chat_model_id != AUTO_MODE_ID:
if chat_model_id < 0:
chat_model = get_global_model(chat_model_id)
if chat_model and _has_capability(chat_model, "vision"):
global_connection = get_global_connection(
chat_model["connection_id"]
)
if global_connection:
model_string, litellm_kwargs = _chat_litellm_from_resolved(
conn=global_connection,
model_id=chat_model["model_id"],
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
)
return SanitizedChatLiteLLM(**litellm_kwargs)
else:
chat_model = await _get_db_model(session, chat_model_id, search_space)
if chat_model and _has_capability(chat_model, "vision"):
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=chat_model.connection,
model_id=chat_model.model_id,
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
)
return SanitizedChatLiteLLM(**litellm_kwargs)
config_id = search_space.vision_model_id
if config_id is None:
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
owner_user_id = search_space.user_id
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
"Vision Auto mode requested but Vision LLM Router not initialized"
)
return None
try:
# Auto mode is currently treated as free at the wrapper
# level — the underlying router can dispatch to either
# premium or free YAML configs but routing decisions are
# opaque. If/when we want to bill Auto-routed vision
# calls we'd need to thread the resolved deployment's
# billing_tier back from the router. For now we keep
# parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
)
except Exception as e:
logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}")
if config_id == AUTO_MODE_ID:
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=owner_user_id,
capability="vision",
)
if not candidates:
logger.error("No vision-capable models available for Auto mode")
return None
config_id = int(
choose_auto_model_candidate(candidates, search_space_id)["id"]
)
if config_id < 0:
global_cfg = get_global_vision_llm_config(config_id)
if not global_cfg:
logger.error(f"Global vision LLM config {config_id} not found")
global_model = get_global_model(config_id)
if not global_model or not _has_capability(global_model, "vision"):
logger.error(f"Global vision model {config_id} not found")
return None
if global_cfg.get("custom_provider"):
provider_prefix = global_cfg["custom_provider"]
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
else:
provider_prefix = VISION_PROVIDER_MAP.get(
global_cfg["provider"].upper(),
global_cfg["provider"].lower(),
global_connection = get_global_connection(global_model["connection_id"])
if not global_connection:
logger.error(
"Global connection %s not found for model %s",
global_model["connection_id"],
config_id,
)
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
return None
litellm_kwargs = {
"model": model_string,
"api_key": global_cfg["api_key"],
}
api_base = resolve_api_base(
provider=global_cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=global_cfg.get("api_base"),
model_string, litellm_kwargs = _chat_litellm_from_resolved(
conn=global_connection,
model_id=global_model["model_id"],
)
if api_base:
litellm_kwargs["api_base"] = api_base
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -581,7 +499,7 @@ async def get_vision_llm(
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
billing_tier = str(global_model.get("billing_tier", "free")).lower()
if billing_tier == "premium":
return QuotaCheckedVisionLLM(
inner_llm,
@ -589,47 +507,23 @@ async def get_vision_llm(
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=model_string,
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
quota_reserve_tokens=global_model.get("catalog", {}).get(
"quota_reserve_tokens"
),
)
return inner_llm
# User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,
VisionLLMConfig.search_space_id == search_space_id,
)
)
vision_cfg = result.scalars().first()
if not vision_cfg:
model = await _get_db_model(session, config_id, search_space)
if not model or not _has_capability(model, "vision"):
logger.error(
f"Vision LLM config {config_id} not found in search space {search_space_id}"
f"Vision model {config_id} not found in search space {search_space_id}"
)
return None
if vision_cfg.custom_provider:
provider_prefix = vision_cfg.custom_provider
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
else:
provider_prefix = VISION_PROVIDER_MAP.get(
vision_cfg.provider.value.upper(),
vision_cfg.provider.value.lower(),
)
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
litellm_kwargs = {
"model": model_string,
"api_key": vision_cfg.api_key,
}
api_base = resolve_api_base(
provider=vision_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=vision_cfg.api_base,
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=model.connection,
model_id=model.model_id,
)
if api_base:
litellm_kwargs["api_base"] = api_base
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,

View file

@ -0,0 +1,36 @@
"""Override-aware model capability lookup."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
CAPABILITY_FIELDS = {
"chat": "supports_chat",
"vision": "supports_image_input",
"image_gen": "supports_image_generation",
"tools": "supports_tools",
}
def _get_value(model: Any, key: str) -> Any:
if isinstance(model, Mapping):
return model.get(key)
return getattr(model, key, None)
def has_capability(model: Any, capability: str) -> bool:
field = CAPABILITY_FIELDS.get(capability)
if field is None:
return False
override = _get_value(model, "capabilities_override") or {}
if isinstance(override, Mapping) and field in override:
return bool(override[field])
if isinstance(override, Mapping) and capability in override:
return bool(override[capability])
return bool(_get_value(model, field))
__all__ = ["CAPABILITY_FIELDS", "has_capability"]

View file

@ -0,0 +1,490 @@
"""Connection verification, model discovery, and capability probing."""
from __future__ import annotations
import contextlib
import logging
from dataclasses import dataclass
from typing import Any
import anyio
import httpx
import litellm
from app.db import Connection, Model, ModelSource
from app.services.model_resolver import ensure_v1, to_litellm
from app.services.openrouter_model_normalizer import normalize_openrouter_models
from app.services.provider_registry import Transport, provider_label, spec_for
logger = logging.getLogger(__name__)
VERIFY_TIMEOUT_SECONDS = 8.0
DISCOVERY_TIMEOUT_SECONDS = 15.0
TEST_TIMEOUT_SECONDS = 30.0
@dataclass(frozen=True)
class VerifyResult:
status: str
ok: bool
message: str = ""
class ModelDiscoveryError(Exception):
"""User-correctable discovery failure for provider configuration issues."""
def _auth_headers(conn: Connection) -> dict[str, str]:
if not conn.api_key:
return {}
return {"Authorization": f"Bearer {conn.api_key}"}
def _anthropic_headers(conn: Connection) -> dict[str, str]:
headers = {"anthropic-version": "2023-06-01"}
if conn.api_key:
headers["x-api-key"] = conn.api_key
return headers
def _base_url_or_default(conn: Connection) -> str | None:
if conn.base_url:
return conn.base_url.rstrip("/")
if conn.provider == "openai":
return "https://api.openai.com/v1"
if conn.provider == "anthropic":
return "https://api.anthropic.com/v1"
return spec_for(conn.provider).default_base_url
def _docker_hint(url: str | None, exc_or_status: Any) -> str:
raw = str(exc_or_status)
if not url:
return raw
if "localhost" in url or "127.0.0.1" in url:
return (
f"{raw}. The backend is running inside Docker; localhost means the "
"backend container. Use host.docker.internal and make sure the model "
"server listens on 0.0.0.0."
)
if "host.docker.internal" in url and (
"refused" in raw.lower() or "connect" in raw.lower()
):
return (
f"{raw}. The host is reachable only if your local model server is "
"listening on 0.0.0.0. On Linux Docker, add "
"`host.docker.internal:host-gateway` to extra_hosts."
)
return raw
def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult:
provider_name = provider_label(conn.provider)
raw = str(exc)
normalized = raw.lower()
exc_name = exc.__class__.__name__.lower()
status_code = getattr(exc, "status_code", None)
logger.info(
"Model test failed for provider=%s model=%s: %s",
conn.provider,
model_id,
raw,
)
if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized:
return VerifyResult(
"AUTH_FAILED",
False,
f"Authentication failed. Check your {provider_name} credentials and try again.",
)
if status_code == 404 or "notfound" in exc_name or "not found" in normalized:
if conn.provider == "azure":
message = (
"Azure OpenAI deployment was not found. Check the deployment name, "
"API version, and endpoint."
)
else:
message = f"Model '{model_id}' was not found on {provider_name}."
return VerifyResult("NOT_FOUND", False, message)
if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized:
return VerifyResult(
"RATE_LIMITED",
False,
f"{provider_name} rate limited the model test. Try again later.",
)
if "timeout" in exc_name or "timed out" in normalized:
return VerifyResult(
"TIMEOUT",
False,
f"{provider_name} did not respond in time. Check the endpoint and try again.",
)
if "connection" in exc_name or "connect" in normalized:
return VerifyResult(
"UNREACHABLE",
False,
_docker_hint(
_base_url_or_default(conn),
f"Could not reach {provider_name}. Check the endpoint and try again.",
),
)
return VerifyResult(
"UNREACHABLE",
False,
f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.",
)
async def verify_connection(conn: Connection) -> VerifyResult:
spec = spec_for(conn.provider)
base_url = _base_url_or_default(conn)
if spec.base_url_required and not base_url:
return VerifyResult("UNREACHABLE", False, "Base URL is required.")
if spec.transport == Transport.OLLAMA and base_url:
url = f"{base_url.rstrip('/')}/api/version"
elif spec.discovery in {"openai_models", "openrouter"} and base_url:
url = f"{ensure_v1(base_url)}/models"
elif spec.discovery == "anthropic_models" and base_url:
url = f"{base_url.rstrip('/')}/models"
else:
return VerifyResult(
"OK", True, "Connection uses provider-native authentication."
)
try:
async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
headers = (
_anthropic_headers(conn)
if spec.auth_style == "x-api-key"
else _auth_headers(conn)
)
response = await client.get(url, headers=headers)
if response.status_code in (401, 403):
return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
if response.status_code == 404:
if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"):
message = "Ollama native API should not use /v1."
elif spec.transport == Transport.OPENAI_COMPATIBLE:
message = "OpenAI-compatible servers should expose /v1/models."
else:
message = "Endpoint returned 404."
return VerifyResult("NOT_FOUND", False, message)
response.raise_for_status()
return VerifyResult("OK", True, "Connection verified.")
except httpx.ConnectError as exc:
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
except httpx.TimeoutException as exc:
return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}")
except httpx.HTTPError as exc:
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
def _discovery_error_message(conn: Connection, exc: httpx.HTTPError) -> str:
base_url = _base_url_or_default(conn)
if isinstance(exc, httpx.HTTPStatusError):
status_code = exc.response.status_code
if status_code in (401, 403):
return "Authentication failed while discovering models."
if status_code == 404:
spec = spec_for(conn.provider)
if spec.transport == Transport.OPENAI_COMPATIBLE:
return "OpenAI-compatible servers should expose /v1/models."
return "Model discovery endpoint returned 404."
return f"Model discovery failed with HTTP {status_code}."
if isinstance(exc, httpx.TimeoutException):
return f"Model discovery timed out: {exc}"
return _docker_hint(base_url, exc)
def _allowlist(conn: Connection) -> set[str]:
raw = (conn.extra or {}).get("model_ids") or []
return {str(item).strip() for item in raw if str(item).strip()}
def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]:
with contextlib.suppress(Exception):
info = litellm.get_model_info(model=model_string)
if isinstance(info, dict):
return info
return (
litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
)
def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]:
info = _litellm_info(model_string, model_id)
mode = info.get("mode")
supports_image_input = False
supports_tools = False
with contextlib.suppress(Exception):
supports_image_input = bool(litellm.supports_vision(model=model_string))
with contextlib.suppress(Exception):
supports_tools = bool(litellm.supports_function_calling(model=model_string))
return {
"supports_chat": mode in (None, "chat", "completion", "responses"),
"max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"),
"supports_image_input": supports_image_input,
"supports_tools": supports_tools,
"supports_image_generation": mode
in {"image_generation", "image_generation_model"},
}
def derive_capabilities(
conn: Connection, model_id: str, metadata: dict | None = None
) -> dict[str, Any]:
metadata = metadata or {}
spec = spec_for(conn.provider)
model_string, _ = to_litellm(conn, model_id)
facts = _classify_from_litellm(model_string, model_id)
if spec.transport == Transport.OLLAMA:
caps = set(metadata.get("capabilities") or [])
details = metadata.get("details") or {}
facts.update(
{
"supports_chat": "embedding" not in caps,
"supports_image_input": "vision" in caps
or facts["supports_image_input"],
"supports_tools": "tools" in caps or facts["supports_tools"],
"supports_image_generation": False,
"max_input_tokens": metadata.get("context_length")
or metadata.get("num_ctx")
or details.get("context_length")
or facts["max_input_tokens"],
}
)
return facts
async def _discover_openai_shaped_models(
conn: Connection, base_url: str | None
) -> list[dict[str, Any]]:
resolved_base_url = base_url or _base_url_or_default(conn)
if not resolved_base_url:
return []
url = f"{ensure_v1(resolved_base_url)}/models"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
response.raise_for_status()
results: list[dict[str, Any]] = []
for item in response.json().get("data", []):
model_id = item.get("id")
if not model_id:
continue
results.append(
{
"model_id": model_id,
"display_name": item.get("name") or model_id,
"source": ModelSource.DISCOVERED,
**derive_capabilities(conn, model_id, item),
"metadata": item,
}
)
return results
async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]:
base_url = _base_url_or_default(conn)
if not base_url:
return []
url = f"{base_url.rstrip('/')}/models"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_anthropic_headers(conn))
response.raise_for_status()
results: list[dict[str, Any]] = []
for item in response.json().get("data", []):
model_id = item.get("id")
if not model_id:
continue
results.append(
{
"model_id": model_id,
"display_name": item.get("display_name") or model_id,
"source": ModelSource.DISCOVERED,
**derive_capabilities(conn, model_id, item),
"metadata": item,
}
)
return results
async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]:
if not conn.base_url:
return []
base_url = conn.base_url.rstrip("/")
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn))
response.raise_for_status()
models = response.json().get("models", [])
results: list[dict[str, Any]] = []
for item in models:
model_id = item.get("model") or item.get("name")
if not model_id:
continue
metadata = dict(item)
with contextlib.suppress(Exception):
show_response = await client.post(
f"{base_url}/api/show",
json={"model": model_id},
headers=_auth_headers(conn),
)
show_response.raise_for_status()
metadata.update(show_response.json())
results.append(
{
"model_id": model_id,
"display_name": item.get("name") or model_id,
"source": ModelSource.DISCOVERED,
**derive_capabilities(conn, model_id, metadata),
"metadata": metadata,
}
)
return results
async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]:
base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(
f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn)
)
response.raise_for_status()
return normalize_openrouter_models(response.json().get("data", []))
def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
provider = conn.provider
prefix = spec_for(provider).litellm_prefix or provider
results: list[dict[str, Any]] = []
for model_string, metadata in litellm.model_cost.items():
if not isinstance(model_string, str) or not model_string.startswith(
f"{prefix}/"
):
continue
model_id = model_string.split("/", 1)[1]
results.append(
{
"model_id": model_id,
"display_name": metadata.get("display_name") or model_id,
"source": ModelSource.DISCOVERED,
**_classify_from_litellm(model_string, model_id),
"metadata": metadata,
}
)
return results
async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
params = (conn.extra or {}).get("litellm_params", {})
region_name = params.get("aws_region_name")
if not region_name:
return []
def list_models() -> list[dict[str, Any]]:
import os
import boto3
if bearer_token := params.get("aws_bearer_token_bedrock"):
try:
os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token
client = boto3.client("bedrock", region_name=region_name)
finally:
os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None)
else:
client_kwargs: dict[str, str] = {"region_name": region_name}
if params.get("aws_access_key_id"):
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
if params.get("aws_secret_access_key"):
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
client = boto3.client("bedrock", **client_kwargs)
response = client.list_foundation_models()
results: list[dict[str, Any]] = []
for item in response.get("modelSummaries", []):
model_id = item.get("modelId")
if not model_id:
continue
input_modalities = set(item.get("inputModalities") or [])
output_modalities = set(item.get("outputModalities") or [])
results.append(
{
"model_id": model_id,
"display_name": item.get("modelName") or model_id,
"source": ModelSource.DISCOVERED,
"supports_chat": "TEXT" in input_modalities
and "TEXT" in output_modalities,
"supports_image_input": "IMAGE" in input_modalities,
"supports_tools": None,
"supports_image_generation": "IMAGE" in output_modalities,
"max_input_tokens": None,
"metadata": item,
}
)
return results
return await anyio.to_thread.run_sync(list_models)
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
allowlist = _allowlist(conn)
spec = spec_for(conn.provider)
try:
if spec.discovery == "ollama":
results = await _ollama_tags_then_show(conn)
elif spec.discovery == "openrouter":
results = await _openrouter_models(conn)
elif spec.discovery == "anthropic_models":
results = await _discover_anthropic_models(conn)
elif spec.discovery == "openai_models":
results = await _discover_openai_shaped_models(conn, conn.base_url)
elif spec.discovery == "bedrock_models":
results = await _discover_bedrock_models(conn)
elif spec.discovery == "static":
results = _litellm_static_models(conn)
else:
results = []
except httpx.HTTPError as exc:
raise ModelDiscoveryError(_discovery_error_message(conn, exc)) from exc
if allowlist:
results = [item for item in results if item["model_id"] in allowlist]
return results
async def test_model(conn: Connection, model: Model) -> VerifyResult:
model_string, kwargs = to_litellm(conn, model.model_id)
try:
await litellm.acompletion(
model=model_string,
messages=[{"role": "user", "content": "Hello"}],
timeout=TEST_TIMEOUT_SECONDS,
**kwargs,
)
except Exception as exc:
return _model_test_error(conn, model.model_id, exc)
model.supports_chat = True
return VerifyResult("OK", True, "Model test succeeded.")
__all__ = [
"ModelDiscoveryError",
"VerifyResult",
"derive_capabilities",
"discover_models",
"test_model",
"verify_connection",
]

View file

@ -12,6 +12,8 @@ from pathlib import Path
import httpx
from app.services.openrouter_model_normalizer import normalize_openrouter_models
logger = logging.getLogger(__name__)
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
@ -22,7 +24,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours
_cache: list[dict] | None = None
_cache_timestamp: float = 0
# Maps OpenRouter provider slug → our LiteLLMProvider enum value.
# Maps OpenRouter provider slug to native LiteLLM provider prefixes.
# Only providers where the model-name part (after the slash) can be
# used directly with the native provider's litellm prefix are listed.
#
@ -121,26 +123,13 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
"""
processed: list[dict] = []
for model in raw_models:
model_id: str = model.get("id", "")
name: str = model.get("name", "")
context_length = model.get("context_length")
for normalized in normalize_openrouter_models(raw_models):
model_id: str = normalized["model_id"]
name: str = normalized.get("display_name") or model_id
context_length = normalized.get("max_input_tokens")
if "/" not in model_id:
continue
if not _is_text_output_model(model):
continue
if not _supports_tool_calling(model):
continue
if not _has_sufficient_context(model):
continue
if not _is_allowed_model(model):
continue
provider_slug, model_name = model_id.split("/", 1)
context_window = _format_context_length(context_length)
@ -154,19 +143,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
}
)
# 2) Emit for the native provider when we have a mapping
native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
if native_provider:
# 2) Emit for the direct provider when we have a mapping
direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
if direct_provider:
# Google's Gemini API only serves gemini-* models.
# Open-source models like gemma-* are NOT available through it.
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
continue
processed.append(
{
"value": model_name,
"label": name,
"provider": native_provider,
"provider": direct_provider,
"context_window": context_window,
}
)

View file

@ -0,0 +1,94 @@
"""Single model-to-LiteLLM resolver.
All chat, vision, image-generation, validation, and Auto routing paths should
turn a Connection + Model into LiteLLM input through this module.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from app.db import Connection
from app.services.provider_registry import Transport, spec_for
def ensure_v1(base_url: str | None) -> str | None:
if not base_url:
return None
stripped = base_url.rstrip("/")
if stripped.endswith("/v1"):
return stripped
return f"{stripped}/v1"
def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any:
if isinstance(conn, Mapping):
return conn.get(key)
return getattr(conn, key)
def to_litellm(
conn: Connection | Mapping[str, Any],
model_id: str,
) -> tuple[str, dict[str, Any]]:
"""Return ``(model_string, litellm_kwargs)`` for any model role."""
provider = _conn_value(conn, "provider")
base_url = _conn_value(conn, "base_url")
api_key = _conn_value(conn, "api_key")
extra = _conn_value(conn, "extra") or {}
spec = spec_for(provider)
kwargs: dict[str, Any] = {}
if api_key:
kwargs["api_key"] = api_key
prefix = spec.litellm_prefix or str(provider)
model_string = f"{prefix}/{model_id}" if prefix else model_id
if base_url:
api_base = (
ensure_v1(base_url)
if spec.transport == Transport.OPENAI_COMPATIBLE
else base_url.rstrip("/")
)
kwargs["api_base"] = api_base
if api_version := extra.get("api_version"):
kwargs["api_version"] = api_version
kwargs.update(extra.get("litellm_params", {}))
kwargs.update(extra.get("kwargs", {}))
if provider == "bedrock" and (
bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)
):
kwargs["api_key"] = bearer_token
return model_string, kwargs
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
"""Build an in-memory connection mapping from a global config."""
provider = str(
config.get("provider")
or config.get("litellm_provider")
or config.get("custom_provider")
or "openai"
)
extra: dict[str, Any] = {
"litellm_params": config.get("litellm_params") or {},
}
if config.get("api_version"):
extra["api_version"] = config.get("api_version")
return {
"provider": provider,
"base_url": config.get("api_base") or None,
"api_key": config.get("api_key") or None,
"extra": extra,
}
__all__ = [
"ensure_v1",
"native_connection_from_config",
"to_litellm",
]

View file

@ -19,6 +19,10 @@ from typing import Any
import httpx
from app.services.openrouter_model_normalizer import (
is_openrouter_image_model,
normalize_openrouter_models,
)
from app.services.quality_score import (
_HEALTH_BLEND_WEIGHT,
_HEALTH_ENRICH_CONCURRENCY,
@ -274,7 +278,7 @@ def _generate_configs(
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
because our own Auto pin + 24 h refresh + repair logic already
cover the catalogue-churn case.
"""
id_offset: int = settings.get("id_offset", -10000)
@ -292,24 +296,16 @@ def _generate_configs(
use_default: bool = settings.get("use_default_system_instructions", True)
citations_enabled: bool = settings.get("citations_enabled", True)
text_models = [
m
for m in raw_models
if _is_text_output_model(m)
and _supports_tool_calling(m)
and _has_sufficient_context(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
text_models = normalize_openrouter_models(raw_models)
configs: list[dict] = []
taken: set[int] = set()
now_ts = int(time.time())
for model in text_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
for normalized in text_models:
model = normalized.get("metadata") or {}
model_id: str = normalized["model_id"]
name: str = normalized.get("display_name") or model_id
tier = _openrouter_tier(model)
static_q = static_score_or(model, now_ts=now_ts)
@ -323,10 +319,10 @@ def _generate_configs(
"seo_enabled": seo_enabled,
"seo_slug": None,
"quota_reserve_tokens": quota_reserve_tokens,
"provider": "OPENROUTER",
"provider": "openrouter",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
@ -345,9 +341,9 @@ def _generate_configs(
# ``stream_new_chat`` as a fail-fast safety net before the
# OpenRouter request would otherwise 404 with
# ``"No endpoints found that support image input"``.
"supports_image_input": _supports_image_input(model),
"supports_image_input": bool(normalized.get("supports_image_input")),
_OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# Auto ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
# start so startup latency is unchanged).
@ -362,11 +358,7 @@ def _generate_configs(
return configs
# ID-offset bands used to keep dynamic OpenRouter configs in their own
# namespace per surface. Image / vision get separate bands so a single
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
def _generate_image_gen_configs(
@ -400,14 +392,7 @@ def _generate_image_gen_configs(
free_rpm: int = settings.get("free_rpm", 20)
litellm_params: dict = settings.get("litellm_params") or {}
image_models = [
m
for m in raw_models
if _is_image_output_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
image_models = [m for m in raw_models if is_openrouter_image_model(m)]
configs: list[dict] = []
taken: set[int] = set()
@ -420,14 +405,9 @@ def _generate_image_gen_configs(
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (image generation)",
"provider": "OPENROUTER",
"provider": "openrouter",
"model_name": model_id,
"api_key": api_key,
# Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
# ``image_generation/transformation`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
@ -440,93 +420,6 @@ def _generate_image_gen_configs(
return configs
def _generate_vision_llm_configs(
raw_models: list[dict], settings: dict[str, Any]
) -> list[dict]:
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
Filter:
- architecture.input_modalities contains "image"
- architecture.output_modalities contains "text"
- compatible provider (excluded slugs blocked)
- allowed model id (excluded list blocked)
Vision-LLM is invoked from the indexer (image extraction during
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
filters do not apply: a small-context vision model that doesn't
advertise tool-calling is still perfectly viable for "describe this
image" prompts.
"""
id_offset: int = int(
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
)
api_key: str = settings.get("api_key", "")
rpm: int = settings.get("rpm", 200)
tpm: int = settings.get("tpm", 1_000_000)
free_rpm: int = settings.get("free_rpm", 20)
free_tpm: int = settings.get("free_tpm", 100_000)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
litellm_params: dict = settings.get("litellm_params") or {}
vision_models = [
m
for m in raw_models
if _is_vision_input_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
configs: list[dict] = []
taken: set[int] = set()
for model in vision_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
pricing = model.get("pricing") or {}
# Capture per-token prices so ``pricing_registration`` can
# register them with LiteLLM at startup (and so the cost
# estimator in ``estimate_call_reserve_micros`` can resolve
# them at reserve time).
try:
input_cost = float(pricing.get("prompt", 0) or 0)
except (TypeError, ValueError):
input_cost = 0.0
try:
output_cost = float(pricing.get("completion", 0) or 0)
except (TypeError, ValueError):
output_cost = 0.0
cfg: dict[str, Any] = {
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (vision)",
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
# Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
"billing_tier": tier,
"quota_reserve_tokens": quota_reserve_tokens,
"input_cost_per_token": input_cost or None,
"output_cost_per_token": output_cost or None,
_OPENROUTER_DYNAMIC_MARKER: True,
}
configs.append(cfg)
return configs
class OpenRouterIntegrationService:
"""Singleton that manages the dynamic OpenRouter model catalogue."""
@ -553,11 +446,9 @@ class OpenRouterIntegrationService:
# Cached raw catalogue from the most recent fetch. Image / vision
# emitters reuse this to avoid a second network call per surface.
self._raw_models: list[dict] = []
# Image / vision config caches (only populated when the matching
# opt-in flag is true on initialize). Refreshed in lockstep with
# the chat catalogue.
# Image config cache (only populated when the matching opt-in flag is
# true on initialize). Refreshed in lockstep with the chat catalogue.
self._image_configs: list[dict] = []
self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@ -592,7 +483,7 @@ class OpenRouterIntegrationService:
self._configs_by_id = {c["id"]: c for c in self._configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
# Populate image / vision caches when their opt-in flag is set.
# Populate image cache when its opt-in flag is set.
# Empty otherwise so the accessors return [] without re-running
# filters every refresh.
if settings.get("image_generation_enabled"):
@ -604,15 +495,6 @@ class OpenRouterIntegrationService:
else:
self._image_configs = []
if settings.get("vision_enabled"):
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
logger.info(
"OpenRouter integration: vision LLM emission ON (%d models)",
len(self._vision_configs),
)
else:
self._vision_configs = []
self._initialized = True
tier_counts = self._tier_counts(self._configs)
@ -666,9 +548,9 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
# Image / vision lists are atomic-swapped the same way: filter out
# Image list is atomic-swapped the same way: filter out
# the previous dynamic entries from the live config list and append
# the freshly generated ones. No-ops when the opt-in flag is off.
# the freshly generated ones. No-op when the opt-in flag is off.
if self._settings.get("image_generation_enabled"):
new_image = _generate_image_gen_configs(raw_models, self._settings)
static_image = [
@ -679,16 +561,6 @@ class OpenRouterIntegrationService:
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
self._image_configs = new_image
if self._settings.get("vision_enabled"):
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
static_vision = [
c
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
self._vision_configs = new_vision
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
@ -710,7 +582,7 @@ class OpenRouterIntegrationService:
)
# Re-blend health scores against the freshly fetched catalogue. Also
# re-stamps health for any YAML-curated cfg with provider==OPENROUTER
# re-stamps health for any YAML-curated cfg with provider=openrouter
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
@ -758,7 +630,7 @@ class OpenRouterIntegrationService:
return counts
# ------------------------------------------------------------------
# Auto (Fastest) health enrichment
# Auto health enrichment
# ------------------------------------------------------------------
async def _enrich_health_safely(
@ -787,7 +659,7 @@ class OpenRouterIntegrationService:
the entire previous cycle's cache for this run.
"""
or_cfgs = [
c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
c for c in configs if str(c.get("provider", "")).lower() == "openrouter"
]
if not or_cfgs:
return
@ -968,17 +840,6 @@ class OpenRouterIntegrationService:
"""
return list(self._image_configs)
def get_vision_llm_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter vision-LLM configs (empty list
when the ``vision_enabled`` flag is off).
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
so ``pricing_registration`` can teach LiteLLM the cost of these
models the same way it does for chat which keeps the billable
wrapper able to debit accurate micro-USD on a vision call.
"""
return list(self._vision_configs)
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
"""Return the cached raw OpenRouter pricing map.

View file

@ -0,0 +1,123 @@
"""Shared OpenRouter model normalization.
OpenRouter metadata is richer than generic OpenAI-compatible ``/models``
responses. Keep all OpenRouter filtering and capability extraction here so
GLOBAL catalogue generation and BYOK discovery agree.
"""
from __future__ import annotations
from typing import Any
from app.db import ModelSource
MIN_CONTEXT_LENGTH = 100_000
EXCLUDED_PROVIDER_SLUGS = {"amazon"}
EXCLUDED_MODEL_IDS: set[str] = {
"openai/gpt-4-1106-preview",
"openai/gpt-4-turbo-preview",
"openai/gpt-4o:extended",
"arcee-ai/virtuoso-large",
"openai/o3-deep-research",
"openai/o4-mini-deep-research",
"openrouter/free",
}
EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
def is_text_output_model(model: dict[str, Any]) -> bool:
output_mods = model.get("architecture", {}).get("output_modalities", [])
return output_mods == ["text"]
def is_image_output_model(model: dict[str, Any]) -> bool:
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
return "image" in output_mods
def supports_image_input(model: dict[str, Any]) -> bool:
input_mods = model.get("architecture", {}).get("input_modalities", []) or []
return "image" in input_mods
def supports_tool_calling(model: dict[str, Any]) -> bool:
supported = model.get("supported_parameters") or []
return "tools" in supported
def has_sufficient_context(model: dict[str, Any]) -> bool:
return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH
def is_compatible_provider(model: dict[str, Any]) -> bool:
model_id = str(model.get("id") or "")
slug = model_id.split("/", 1)[0] if "/" in model_id else ""
return slug not in EXCLUDED_PROVIDER_SLUGS
def is_allowed_model(model: dict[str, Any]) -> bool:
model_id = str(model.get("id") or "")
if model_id in EXCLUDED_MODEL_IDS:
return False
base_id = model_id.split(":")[0]
return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES)
def is_openrouter_chat_model(model: dict[str, Any]) -> bool:
return (
"/" in str(model.get("id") or "")
and is_text_output_model(model)
and supports_tool_calling(model)
and has_sufficient_context(model)
and is_compatible_provider(model)
and is_allowed_model(model)
)
def is_openrouter_image_model(model: dict[str, Any]) -> bool:
return (
"/" in str(model.get("id") or "")
and is_image_output_model(model)
and is_compatible_provider(model)
and is_allowed_model(model)
)
def normalize_openrouter_models(
raw_models: list[dict[str, Any]],
) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for model in raw_models:
if not is_openrouter_chat_model(model):
continue
model_id = str(model.get("id") or "")
normalized.append(
{
"model_id": model_id,
"display_name": model.get("name") or model_id,
"source": ModelSource.DISCOVERED,
"supports_chat": True,
"max_input_tokens": model.get("context_length"),
"supports_image_input": supports_image_input(model),
"supports_tools": supports_tool_calling(model),
"supports_image_generation": False,
"metadata": model,
}
)
return normalized
__all__ = [
"MIN_CONTEXT_LENGTH",
"has_sufficient_context",
"is_allowed_model",
"is_compatible_provider",
"is_image_output_model",
"is_openrouter_chat_model",
"is_openrouter_image_model",
"is_text_output_model",
"normalize_openrouter_models",
"supports_image_input",
"supports_tool_calling",
]

View file

@ -143,21 +143,19 @@ def _register_chat_shape_configs(
sample_keys: list[str] = []
for cfg in configs:
provider = str(cfg.get("provider") or "").upper()
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
model_name = str(cfg.get("model_name") or "").strip()
litellm_params = cfg.get("litellm_params") or {}
base_model = str(litellm_params.get("base_model") or model_name).strip()
if provider == "OPENROUTER":
if provider == "openrouter":
entry = or_pricing.get(model_name)
if entry:
input_cost = _safe_float(entry.get("prompt"))
output_cost = _safe_float(entry.get("completion"))
else:
# Vision configs from ``_generate_vision_llm_configs``
# carry their pricing inline because the OpenRouter
# raw-pricing cache is keyed by chat-catalogue model_id;
# vision flows pick up the inline values here.
# Some dynamically materialized configs can carry pricing
# inline when the raw OpenRouter cache has no matching entry.
input_cost = _safe_float(cfg.get("input_cost_per_token"))
output_cost = _safe_float(cfg.get("output_cost_per_token"))
if input_cost == 0.0 and output_cost == 0.0:
@ -189,12 +187,11 @@ def _register_chat_shape_configs(
skipped_no_pricing += 1
continue
aliases = _alias_set_for_yaml(provider, model_name, base_model)
provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
count = _register(
aliases,
input_cost=input_cost,
output_cost=output_cost,
provider=provider_slug,
provider=provider,
)
if count > 0:
registered_models += 1
@ -217,9 +214,8 @@ def _register_chat_shape_configs(
def register_pricing_from_global_configs() -> None:
"""Register pricing for every known LLM deployment with LiteLLM.
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
so vision calls (during indexing) can resolve cost the same way chat
calls do namely:
Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve
cost from the same chat-shaped deployment configs:
1. ``OPENROUTER``: pulls the cached raw pricing from
``OpenRouterIntegrationService`` (populated during its own
@ -246,10 +242,7 @@ def register_pricing_from_global_configs() -> None:
from app.config import config as app_config
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
vision_configs: list[dict] = list(
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
)
if not chat_configs and not vision_configs:
if not chat_configs:
logger.info("[PricingRegistration] no global configs to register")
return
@ -268,7 +261,3 @@ def register_pricing_from_global_configs() -> None:
if chat_configs:
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
if vision_configs:
_register_chat_shape_configs(
vision_configs, or_pricing=or_pricing, label="vision"
)

View file

@ -1,106 +0,0 @@
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
LiteLLM falls back to the module-global ``litellm.api_base`` when an
individual call doesn't pass one, which silently inherits provider-agnostic
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
``/chat/completions`` to whatever inherited base it gets, regardless of
provider).
The chat router has had this defense for a while
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
into a tiny standalone helper so vision and image-gen can share the same
source of truth without an inter-service circular import.
"""
from __future__ import annotations
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"xai": "https://api.x.ai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
"together_ai": "https://api.together.xyz/v1",
"anyscale": "https://api.endpoints.anyscale.com/v1",
"cometapi": "https://api.cometapi.com/v1",
"sambanova": "https://api.sambanova.ai/v1",
}
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
Only providers with a well-known, stable public base URL are listed
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
huggingface, databricks, cloudflare, replicate) are intentionally omitted
so their existing config-driven behaviour is preserved."""
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"MOONSHOT": "https://api.moonshot.ai/v1",
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
"MINIMAX": "https://api.minimax.io/v1",
}
"""Canonical provider key (uppercase) → base URL.
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
config's ``provider`` field tells us which API it actually is (DeepSeek,
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
has its own base URL)."""
def resolve_api_base(
*,
provider: str | None,
provider_prefix: str | None,
config_api_base: str | None,
) -> str | None:
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
Cascade (first non-empty wins):
1. The config's own ``api_base`` (whitespace-only treated as missing).
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
4. ``None`` caller should NOT set ``api_base`` and let the LiteLLM
provider integration apply its own default (e.g. AzureOpenAI's
deployment-derived URL, custom provider's per-deployment URL).
Args:
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
``"DEEPSEEK"``). Case-insensitive.
provider_prefix: The LiteLLM model-string prefix the same call
site builds for the model id (e.g. ``"openrouter"``,
``"groq"``). Case-insensitive.
config_api_base: ``api_base`` from the global YAML / DB row /
OpenRouter dynamic config. Empty / whitespace-only means
"missing" the resolver still applies the cascade.
Returns:
A URL string, or ``None`` if no default applies for this provider.
"""
if config_api_base and config_api_base.strip():
return config_api_base
if provider:
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
if key_default:
return key_default
if provider_prefix:
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
if prefix_default:
return prefix_default
return None
__all__ = [
"PROVIDER_DEFAULT_API_BASE",
"PROVIDER_KEY_DEFAULT_API_BASE",
"resolve_api_base",
]

View file

@ -49,51 +49,6 @@ import litellm
logger = logging.getLogger(__name__)
# Provider-name → LiteLLM model-prefix map.
#
# Owned here because ``app.services.provider_capabilities`` is the
# only edge that's safe to call from ``app.config``'s YAML loader at
# class-body init time. ``app.agents.chat.runtime.llm_config`` re-exports
# this constant under the historical ``PROVIDER_MAP`` name; placing the
# map there directly would re-introduce the
# ``app.config -> ... -> deliverables/tools/generate_image ->
# app.config`` cycle that prompted the move.
_PROVIDER_PREFIX_MAP: dict[str, str] = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"XAI": "xai",
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
def _candidate_model_strings(
*,
provider: str | None,
@ -123,12 +78,7 @@ def _candidate_model_strings(
seen.add(key)
candidates.append(key)
provider_prefix: str | None = None
if provider:
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
if custom_provider:
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
provider_prefix = custom_provider
provider_prefix = custom_provider or provider
primary_model = base_model or model_name
bare_model = model_name

View file

@ -0,0 +1,126 @@
"""Provider registry for model connections.
The provider string is the single public identity axis. This registry only
describes providers whose behavior differs from LiteLLM's native default.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
from typing import Literal
class Transport(StrEnum):
NATIVE = "NATIVE"
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
OLLAMA = "OLLAMA"
DiscoveryKind = Literal[
"ollama",
"openai_models",
"anthropic_models",
"bedrock_models",
"openrouter",
"static",
"none",
]
AuthStyle = Literal["bearer", "x-api-key", "none", "native"]
@dataclass(frozen=True)
class ProviderSpec:
transport: Transport
litellm_prefix: str | None
discovery: DiscoveryKind
default_base_url: str | None
base_url_required: bool
auth_style: AuthStyle
display_name: str | None = None
REGISTRY: dict[str, ProviderSpec] = {
"openai": ProviderSpec(
Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI"
),
"anthropic": ProviderSpec(
Transport.NATIVE,
"anthropic",
"anthropic_models",
None,
False,
"x-api-key",
"Anthropic",
),
"azure": ProviderSpec(
Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI"
),
"vertex_ai": ProviderSpec(
Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI"
),
"bedrock": ProviderSpec(
Transport.NATIVE,
"bedrock",
"bedrock_models",
None,
False,
"native",
"Amazon Bedrock",
),
"openrouter": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
"openrouter",
"openrouter",
"https://openrouter.ai/api/v1",
False,
"bearer",
"OpenRouter",
),
"openai_compatible": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
"openai",
"openai_models",
None,
True,
"bearer",
"OpenAI-compatible provider",
),
"lm_studio": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
"openai",
"openai_models",
"http://host.docker.internal:1234/v1",
True,
"bearer",
"LM Studio",
),
"ollama_chat": ProviderSpec(
Transport.OLLAMA,
"ollama_chat",
"ollama",
"http://host.docker.internal:11434",
True,
"none",
"Ollama",
),
}
def spec_for(provider: str | None) -> ProviderSpec:
provider_key = (provider or "").strip()
return REGISTRY.get(provider_key) or ProviderSpec(
Transport.NATIVE, provider_key or "openai", "static", None, False, "native"
)
def provider_label(provider: str | None) -> str:
provider_key = (provider or "").strip()
spec = spec_for(provider_key)
if spec.display_name:
return spec.display_name
return provider_key.replace("_", " ").title() if provider_key else "Provider"
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"]

View file

@ -1,4 +1,4 @@
"""Pure-function quality scoring for Auto (Fastest) model selection.
"""Pure-function quality scoring for Auto model selection.
This module is import-free of any service / request-path dependencies. All
numbers are computed once during the OpenRouter refresh tick (or YAML load)
@ -108,25 +108,23 @@ PROVIDER_PRESTIGE_OR: dict[str, int] = {
# YAML provider field (the upstream API shape the operator selected).
PROVIDER_PRESTIGE_YAML: dict[str, int] = {
"AZURE_OPENAI": 50,
"OPENAI": 50,
"ANTHROPIC": 50,
"GOOGLE": 50,
"VERTEX_AI": 50,
"GEMINI": 50,
"XAI": 50,
"MISTRAL": 38,
"DEEPSEEK": 38,
"COHERE": 38,
"GROQ": 30,
"TOGETHER_AI": 28,
"FIREWORKS_AI": 28,
"PERPLEXITY": 28,
"MINIMAX": 28,
"BEDROCK": 28,
"OPENROUTER": 25,
"OLLAMA": 12,
"CUSTOM": 12,
"azure": 50,
"openai": 50,
"anthropic": 50,
"gemini": 50,
"vertex_ai": 50,
"xai": 50,
"mistral": 38,
"deepseek": 38,
"cohere": 38,
"groq": 30,
"together_ai": 28,
"fireworks_ai": 28,
"perplexity": 28,
"bedrock": 28,
"openrouter": 25,
"ollama_chat": 12,
"custom": 12,
}
@ -275,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int:
listed this model. Pricing / context fall through to lazy ``litellm``
lookups; failures are silent (we just lose those sub-points).
"""
provider = str(cfg.get("provider", "")).upper()
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
model_name = cfg.get("model_name") or ""

View file

@ -40,6 +40,10 @@ class TokenCallRecord:
total_tokens: int
cost_micros: int = 0
call_kind: str = "chat"
model_ref: str | None = None
model_id: str | None = None
display_name: str | None = None
provider: str | None = None
@dataclass
@ -47,6 +51,24 @@ class TurnTokenAccumulator:
"""Accumulates token usage across all LLM calls within a single user turn."""
calls: list[TokenCallRecord] = field(default_factory=list)
model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict)
def register_model_metadata(
self,
*,
model: str,
model_ref: str | None,
model_id: str | None,
display_name: str | None,
provider: str | None,
) -> None:
"""Attach resolved model metadata for later LiteLLM callback attribution."""
self.model_metadata[model] = {
"model_ref": model_ref,
"model_id": model_id,
"display_name": display_name,
"provider": provider,
}
def add(
self,
@ -57,9 +79,14 @@ class TurnTokenAccumulator:
cost_micros: int = 0,
call_kind: str = "chat",
) -> None:
metadata = self.model_metadata.get(model, {})
self.calls.append(
TokenCallRecord(
model=model,
model_ref=metadata.get("model_ref"),
model_id=metadata.get("model_id"),
display_name=metadata.get("display_name"),
provider=metadata.get("provider"),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
@ -68,13 +95,18 @@ class TurnTokenAccumulator:
)
)
def per_message_summary(self) -> dict[str, dict[str, int]]:
def per_message_summary(self) -> dict[str, dict[str, Any]]:
"""Return token counts (and cost) grouped by model name."""
by_model: dict[str, dict[str, int]] = {}
by_model: dict[str, dict[str, Any]] = {}
for c in self.calls:
entry = by_model.setdefault(
c.model,
{
"model": c.model,
"model_ref": c.model_ref,
"model_id": c.model_id,
"display_name": c.display_name,
"provider": c.provider,
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
@ -142,6 +174,27 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
return _turn_accumulator.get()
def register_model_usage_metadata(
*,
model: str,
model_ref: str | None,
model_id: str | None,
display_name: str | None,
provider: str | None,
) -> None:
"""Register resolved model metadata with the current turn, if one exists."""
acc = _turn_accumulator.get()
if acc is None:
return
acc.register_model_metadata(
model=model,
model_ref=model_ref,
model_id=model_id,
display_name=display_name,
provider=provider,
)
@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
"""Async context manager that scopes a fresh ``TurnTokenAccumulator``

View file

@ -1,201 +0,0 @@
import logging
from typing import Any
from litellm import Router
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
VISION_PROVIDER_MAP = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GOOGLE": "gemini",
"AZURE_OPENAI": "azure",
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock",
"XAI": "xai",
"OPENROUTER": "openrouter",
"OLLAMA": "ollama_chat",
"GROQ": "groq",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"MISTRAL": "mistral",
"CUSTOM": "custom",
}
class VisionLLMRouterService:
_instance = None
_router: Router | None = None
_model_list: list[dict] = []
_router_settings: dict = {}
_initialized: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@classmethod
def get_instance(cls) -> "VisionLLMRouterService":
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def initialize(
cls,
global_configs: list[dict],
router_settings: dict | None = None,
) -> None:
instance = cls.get_instance()
if instance._initialized:
logger.debug("Vision LLM Router already initialized, skipping")
return
model_list = []
for config in global_configs:
deployment = cls._config_to_deployment(config)
if deployment:
model_list.append(deployment)
if not model_list:
logger.warning(
"No valid vision LLM configs found for router initialization"
)
return
instance._model_list = model_list
instance._router_settings = router_settings or {}
default_settings = {
"routing_strategy": "usage-based-routing",
"num_retries": 3,
"allowed_fails": 3,
"cooldown_time": 60,
"retry_after": 5,
}
final_settings = {**default_settings, **instance._router_settings}
try:
instance._router = Router(
model_list=model_list,
routing_strategy=final_settings.get(
"routing_strategy", "usage-based-routing"
),
num_retries=final_settings.get("num_retries", 3),
allowed_fails=final_settings.get("allowed_fails", 3),
cooldown_time=final_settings.get("cooldown_time", 60),
set_verbose=False,
)
instance._initialized = True
logger.info(
"Vision LLM Router initialized with %d deployments, strategy: %s",
len(model_list),
final_settings.get("routing_strategy"),
)
except Exception as e:
logger.error(f"Failed to initialize Vision LLM Router: {e}")
instance._router = None
@classmethod
def _config_to_deployment(cls, config: dict) -> dict | None:
try:
if not config.get("model_name") or not config.get("api_key"):
return None
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
litellm_params: dict[str, Any] = {
"model": model_string,
"api_key": config.get("api_key"),
}
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]
if config.get("litellm_params"):
litellm_params.update(config["litellm_params"])
deployment: dict[str, Any] = {
"model_name": "auto",
"litellm_params": litellm_params,
}
if config.get("rpm"):
deployment["rpm"] = config["rpm"]
if config.get("tpm"):
deployment["tpm"] = config["tpm"]
return deployment
except Exception as e:
logger.warning(f"Failed to convert vision config to deployment: {e}")
return None
@classmethod
def get_router(cls) -> Router | None:
instance = cls.get_instance()
return instance._router
@classmethod
def is_initialized(cls) -> bool:
instance = cls.get_instance()
return instance._initialized and instance._router is not None
@classmethod
def get_model_count(cls) -> int:
instance = cls.get_instance()
return len(instance._model_list)
def is_vision_auto_mode(config_id: int | None) -> bool:
return config_id == VISION_AUTO_MODE_ID
def build_vision_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}"
def get_global_vision_llm_config(config_id: int) -> dict | None:
from app.config import config
if config_id == VISION_AUTO_MODE_ID:
return {
"id": VISION_AUTO_MODE_ID,
"name": "Auto (Fastest)",
"provider": "AUTO",
"model_name": "auto",
"is_auto_mode": True,
}
if config_id > 0:
return None
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
if cfg.get("id") == config_id:
return cfg
return None

View file

@ -1,134 +0,0 @@
"""
Service for fetching and caching the vision-capable model list.
Reuses the same OpenRouter public API and local fallback as the LLM model
list service, but filters for models that accept image input.
"""
import json
import logging
import time
from pathlib import Path
import httpx
logger = logging.getLogger(__name__)
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
FALLBACK_FILE = (
Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json"
)
CACHE_TTL_SECONDS = 86400 # 24 hours
_cache: list[dict] | None = None
_cache_timestamp: float = 0
OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = {
"openai": "OPENAI",
"anthropic": "ANTHROPIC",
"google": "GOOGLE",
"mistralai": "MISTRAL",
"x-ai": "XAI",
}
def _format_context_length(length: int | None) -> str | None:
if not length:
return None
if length >= 1_000_000:
return f"{length / 1_000_000:g}M"
if length >= 1_000:
return f"{length / 1_000:g}K"
return str(length)
async def _fetch_from_openrouter() -> list[dict] | None:
try:
async with httpx.AsyncClient(timeout=15) as client:
response = await client.get(OPENROUTER_API_URL)
response.raise_for_status()
data = response.json()
return data.get("data", [])
except Exception as e:
logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e)
return None
def _load_fallback() -> list[dict]:
try:
with open(FALLBACK_FILE, encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error("Failed to load vision model fallback list: %s", e)
return []
def _is_vision_model(model: dict) -> bool:
"""Return True if the model accepts image input and outputs text."""
arch = model.get("architecture", {})
input_mods = arch.get("input_modalities", [])
output_mods = arch.get("output_modalities", [])
return "image" in input_mods and "text" in output_mods
def _process_vision_models(raw_models: list[dict]) -> list[dict]:
processed: list[dict] = []
for model in raw_models:
model_id: str = model.get("id", "")
name: str = model.get("name", "")
context_length = model.get("context_length")
if "/" not in model_id:
continue
if not _is_vision_model(model):
continue
provider_slug, model_name = model_id.split("/", 1)
context_window = _format_context_length(context_length)
processed.append(
{
"value": model_id,
"label": name,
"provider": "OPENROUTER",
"context_window": context_window,
}
)
native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
if native_provider:
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
continue
processed.append(
{
"value": model_name,
"label": name,
"provider": native_provider,
"context_window": context_window,
}
)
return processed
async def get_vision_model_list() -> list[dict]:
global _cache, _cache_timestamp
if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS:
return _cache
raw_models = await _fetch_from_openrouter()
if raw_models is None:
logger.info("Using fallback vision model list")
return _load_fallback()
processed = _process_vision_models(raw_models)
_cache = processed
_cache_timestamp = time.time()
return processed

View file

@ -0,0 +1,88 @@
"""Convert persisted chat content into provider-safe LangChain history.
Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module
extracts only model-safe content before prior turns are replayed to a provider.
"""
from __future__ import annotations
from typing import Any
_USER_CONTENT_TYPES = {"text", "image", "image_url"}
def _text_from_block(block: dict[str, Any]) -> str:
value = block.get("text") or block.get("content") or ""
return value if isinstance(value, str) else ""
def assistant_content_to_llm_text(content: Any) -> str:
"""Return visible assistant text, dropping reasoning/UI/provider blocks."""
if isinstance(content, str):
return content
if isinstance(content, dict):
return _text_from_block(content)
if not isinstance(content, list):
return ""
text_chunks: list[str] = []
for block in content:
if isinstance(block, str):
if block:
text_chunks.append(block)
continue
if not isinstance(block, dict):
continue
if block.get("type") == "text":
text = _text_from_block(block)
if text:
text_chunks.append(text)
return "\n".join(text_chunks)
def user_content_to_llm_content(
content: Any,
*,
allow_images: bool = True,
) -> str | list[dict[str, Any]]:
"""Return provider-safe user text/image content for LangChain."""
if isinstance(content, str):
return content
if isinstance(content, dict):
return _text_from_block(content)
if not isinstance(content, list):
return ""
parts: list[dict[str, Any]] = []
text_chunks: list[str] = []
for block in content:
if isinstance(block, str):
if block:
text_chunks.append(block)
continue
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type not in _USER_CONTENT_TYPES:
continue
if block_type == "text":
text = _text_from_block(block)
if text:
parts.append({"type": "text", "text": text})
text_chunks.append(text)
elif allow_images and block_type == "image":
image = block.get("image")
if isinstance(image, str) and image.startswith("data:"):
parts.append({"type": "image_url", "image_url": {"url": image}})
elif allow_images and block_type == "image_url":
image_url = block.get("image_url")
if isinstance(image_url, dict):
url = image_url.get("url")
if isinstance(url, str) and url.startswith("data:"):
parts.append({"type": "image_url", "image_url": {"url": url}})
elif isinstance(image_url, str) and image_url.startswith("data:"):
parts.append({"type": "image_url", "image_url": {"url": image_url}})
if allow_images and any(part.get("type") == "image_url" for part in parts):
return parts
return "\n".join(text_chunks)

View file

@ -0,0 +1,89 @@
"""Normalize final LangChain assistant messages into assistant-ui parts.
Live streaming remains the primary source for rich, incremental UI state.
This module is only used after the graph has finished so refresh persistence
does not depend on provider-specific streaming chunk shapes.
"""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
from langchain_core.messages import AIMessage
def _text_from_content(content: Any) -> str:
if isinstance(content, str):
return content
if not isinstance(content, list):
return ""
text_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") != "text":
continue
value = block.get("text") or block.get("content") or ""
if isinstance(value, str) and value:
text_parts.append(value)
return "".join(text_parts)
def normalize_ai_message_to_parts(
message: AIMessage | Any | None,
) -> list[dict[str, Any]]:
"""Return user-visible assistant-ui parts for a final AI message.
We intentionally do not backfill provider ``thinking`` /
``reasoning_content`` blocks here. If reasoning streamed live, the
``AssistantContentBuilder`` already captured it. If it only exists in the
final model payload, persisting it retroactively could expose content the
UI never showed during the turn.
"""
if message is None:
return []
text = _text_from_content(getattr(message, "content", None)).strip()
if not text:
return []
return [{"type": "text", "text": text}]
def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None:
if messages is None:
return None
for message in reversed(list(messages)):
if isinstance(message, AIMessage):
return message
if getattr(message, "type", None) == "ai":
return message
return None
def final_assistant_parts_from_messages(
messages: Iterable[Any] | None,
) -> list[dict[str, Any]]:
return normalize_ai_message_to_parts(last_ai_message(messages))
def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool:
return any(
part.get("type") == "text"
and isinstance(part.get("text"), str)
and bool(part.get("text", "").strip())
for part in parts
)
def merge_streamed_and_final_parts(
streamed_parts: list[dict[str, Any]],
final_parts: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Use final-state text only when streaming captured no answer text."""
if has_non_empty_text_part(streamed_parts):
return streamed_parts
if not has_non_empty_text_part(final_parts):
return streamed_parts
return [*streamed_parts, *final_parts]

View file

@ -16,6 +16,9 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence impor
)
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.message_parts_normalizer import (
final_assistant_parts_from_messages,
)
from app.tasks.chat.streaming.contract.file_contract import (
contract_enforcement_active,
evaluate_file_contract_outcome,
@ -75,6 +78,9 @@ async def stream_agent_events(
state = await agent.aget_state(config)
state_values = getattr(state, "values", {}) or {}
result.final_message_parts = final_assistant_parts_from_messages(
state_values.get("messages")
)
# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work

View file

@ -12,6 +12,7 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import (
is_cancel_requested,
)
from app.agents.chat.runtime.errors import BusyError
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
@ -102,6 +103,9 @@ def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
def is_provider_rate_limited(exc: BaseException) -> bool:
"""Return True if the exception looks like an upstream HTTP 429 / rate limit."""
if adapt_llm_exception(exc).category is LLMErrorCategory.RATE_LIMITED:
return True
raw = str(exc)
lowered = raw.lower()
if "ratelimit" in type(exc).__name__.lower():
@ -131,6 +135,85 @@ def is_provider_rate_limited(exc: BaseException) -> bool:
)
def _provider_error_extra(adapted: Any) -> dict[str, Any] | None:
extra: dict[str, Any] = {"provider_error_category": adapted.category.value}
if adapted.provider_status_code is not None:
extra["provider_status_code"] = adapted.provider_status_code
if adapted.provider_error_type:
extra["provider_error_type"] = adapted.provider_error_type
return extra
def _classify_provider_exception(
exc: Exception,
) -> (
tuple[str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None]
| None
):
adapted = adapt_llm_exception(exc)
if adapted.category is LLMErrorCategory.RATE_LIMITED:
return (
"rate_limited",
"RATE_LIMITED",
"warn",
True,
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
_provider_error_extra(adapted),
)
if adapted.category in {
LLMErrorCategory.AUTH_FAILED,
LLMErrorCategory.PERMISSION_DENIED,
}:
return (
"model_auth_failed",
"MODEL_AUTH_FAILED",
"warn",
True,
"This model's API key is invalid or expired. Switch models, or update the API key.",
_provider_error_extra(adapted),
)
if adapted.category is LLMErrorCategory.MODEL_NOT_FOUND:
return (
"model_not_found",
"MODEL_NOT_FOUND",
"warn",
True,
"The selected model is unavailable or no longer exists. Switch to another model and try again.",
_provider_error_extra(adapted),
)
if adapted.category is LLMErrorCategory.CONTEXT_LIMIT:
return (
"model_context_limit",
"MODEL_CONTEXT_LIMIT",
"warn",
True,
"This request is too large for the selected model. Try a model with a larger context window or reduce the input.",
_provider_error_extra(adapted),
)
if adapted.category in {
LLMErrorCategory.TIMEOUT,
LLMErrorCategory.PROVIDER_UNAVAILABLE,
LLMErrorCategory.BAD_GATEWAY,
LLMErrorCategory.CONNECTION_FAILED,
LLMErrorCategory.SERVER_ERROR,
}:
return (
"model_provider_unavailable",
"MODEL_PROVIDER_UNAVAILABLE",
"warn",
True,
"The selected model provider is temporarily unavailable. Please try again or switch models.",
_provider_error_extra(adapted),
)
return None
def classify_stream_exception(
exc: Exception,
*,
@ -167,15 +250,9 @@ def classify_stream_exception(
None,
)
if is_provider_rate_limited(exc):
return (
"rate_limited",
"RATE_LIMITED",
"warn",
True,
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
None,
)
provider_classification = _classify_provider_exception(exc)
if provider_classification is not None:
return provider_classification
return (
"server_error",

View file

@ -80,7 +80,6 @@ async def _generate_title(
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator
# Excludes this turn's own assistant row (pre-written by
@ -125,26 +124,12 @@ async def _generate_title(
router = LLMRouterService.get_router()
response = await router.acompletion(model="auto", messages=messages)
else:
# Apply the same ``api_base`` cascade chat / vision / image-gen
# call sites use so we never inherit ``litellm.api_base``
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
# config itself ships an empty ``api_base``. Without this the
# title-gen on an OpenRouter chat config would 404 against the
# inherited Azure endpoint — see ``provider_api_base`` for the
# same bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or ""
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
provider_value = agent_config.provider if agent_config is not None else None
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,
config_api_base=getattr(llm, "api_base", None),
)
response = await acompletion(
model=raw_model,
messages=messages,
api_key=getattr(llm, "api_key", None),
api_base=title_api_base,
api_base=getattr(llm, "api_base", None),
)
usage_info = None

View file

@ -53,6 +53,7 @@ async def finalize_assistant_message(
):
return
from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts
from app.tasks.chat.persistence import finalize_assistant_turn
builder_stats: dict[str, int] | None = None
@ -74,6 +75,10 @@ async def finalize_assistant_message(
"text": stream_result.accumulated_text or "",
}
]
content_payload = merge_streamed_and_final_parts(
content_payload,
stream_result.final_message_parts,
)
if builder_stats is not None:
_perf_log.info(

View file

@ -1,8 +1,8 @@
"""Load an LLM + AgentConfig bundle for a given config id.
Handles both code paths uniformly:
- ``config_id >= 0`` database-backed ``NewLLMConfig`` row (per-user/per-space).
- ``config_id < 0`` YAML-defined global LLM config (built-in defaults).
- ``config_id > 0`` database-backed model-connection ``Model`` row.
- ``config_id < 0`` virtual global model materialized from YAML/OpenRouter.
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
``None``. The caller emits the friendly SSE error frame.
@ -12,15 +12,78 @@ from __future__ import annotations
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.agents.chat.runtime.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
create_chat_litellm_from_config,
load_agent_config,
load_global_llm_config_by_id,
SanitizedChatLiteLLM,
)
from app.config import config
from app.db import Model, SearchSpace
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
from app.services.token_tracking_service import register_model_usage_metadata
def _agent_config_from_resolved(
*,
config_id: int,
config_name: str | None,
provider: str,
model_name: str,
api_key: str | None,
api_base: str | None,
litellm_params: dict | None,
supports_image_input: bool,
billing_tier: str = "free",
) -> AgentConfig:
return AgentConfig(
provider=provider,
model_name=model_name,
api_key=api_key or "",
api_base=api_base,
custom_provider=None,
litellm_params=litellm_params,
config_id=config_id,
config_name=config_name,
is_auto_mode=False,
billing_tier=billing_tier,
is_premium=billing_tier == "premium",
supports_image_input=supports_image_input,
)
async def _load_search_space(
session: AsyncSession, search_space_id: int
) -> SearchSpace | None:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
)
return result.scalars().first()
async def _load_db_model(
session: AsyncSession,
*,
model_id: int,
search_space: SearchSpace,
) -> Model | None:
result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == model_id, Model.enabled.is_(True))
)
model = result.scalars().first()
if not model or not model.connection or not model.connection.enabled:
return None
conn = model.connection
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
return None
if conn.user_id is not None and conn.user_id != search_space.user_id:
return None
return model
async def load_llm_bundle(
@ -29,29 +92,89 @@ async def load_llm_bundle(
config_id: int,
search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
if config_id >= 0:
loaded_agent_config = await load_agent_config(
session=session,
config_id=config_id,
search_space_id=search_space_id,
search_space = await _load_search_space(session, search_space_id)
if not search_space:
return None, None, f"Search space {search_space_id} not found"
if config_id > 0:
model = await _load_db_model(
session,
model_id=config_id,
search_space=search_space,
)
if not loaded_agent_config:
if not model or not has_capability(model, "chat"):
return (
None,
None,
f"Failed to load NewLLMConfig with id {config_id}",
f"Failed to load chat model with id {config_id}",
)
model_string, litellm_kwargs = to_litellm(model.connection, model.model_id)
display_name = model.display_name or model.model_id
provider = model.connection.provider or ""
register_model_usage_metadata(
model=model_string,
model_ref=f"db:{model.id}",
model_id=model.model_id,
display_name=display_name,
provider=provider,
)
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=display_name,
provider=provider,
model_name=model.model_id,
api_key=model.connection.api_key,
api_base=model.connection.base_url,
litellm_params=(model.connection.extra or {}).get("litellm_params"),
supports_image_input=has_capability(model, "vision"),
billing_tier="free",
)
return (
create_chat_litellm_from_agent_config(loaded_agent_config),
loaded_agent_config,
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
agent_config,
None,
)
loaded_llm_config = load_global_llm_config_by_id(config_id)
if not loaded_llm_config:
return None, None, f"Failed to load LLM config with id {config_id}"
return (
create_chat_litellm_from_config(loaded_llm_config),
AgentConfig.from_yaml_config(loaded_llm_config),
global_model = next(
(m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None
)
if not global_model or not has_capability(global_model, "chat"):
return None, None, f"Failed to load global chat model with id {config_id}"
global_connection = next(
(
c
for c in config.GLOBAL_CONNECTIONS
if c.get("id") == global_model.get("connection_id")
),
None,
)
if not global_connection:
return None, None, f"Failed to load global connection for model {config_id}"
model_string, litellm_kwargs = to_litellm(
global_connection, global_model["model_id"]
)
display_name = global_model.get("display_name") or global_model.get("model_id")
provider = global_connection.get("provider") or ""
register_model_usage_metadata(
model=model_string,
model_ref=f"global:{config_id}",
model_id=global_model["model_id"],
display_name=display_name,
provider=provider,
)
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=display_name,
provider=provider,
model_name=global_model["model_id"],
api_key=global_connection.get("api_key"),
api_base=global_connection.get("base_url"),
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
supports_image_input=has_capability(global_model, "vision"),
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
)
return (
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
agent_config,
None,
)

View file

@ -35,3 +35,7 @@ class StreamResult:
# (``StreamResult`` is logged in some error branches) from dumping a
# potentially-large parts list.
content_builder: Any | None = field(default=None, repr=False)
# User-visible assistant message parts derived from the final LangGraph
# state. Used after streaming completes as a provider-agnostic persistence
# backfill when no text chunks reached the live stream.
final_message_parts: list[dict[str, Any]] = field(default_factory=list)

View file

@ -18,6 +18,11 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.tasks.chat.llm_history_normalizer import (
assistant_content_to_llm_text,
user_content_to_llm_content,
)
if TYPE_CHECKING:
from app.db import ChatVisibility
@ -95,17 +100,28 @@ async def bootstrap_history_from_db(
langchain_messages: list[HumanMessage | AIMessage] = []
for msg in db_messages:
text_content = extract_text_content(msg.content)
if not text_content:
continue
if msg.role == "user":
user_content = user_content_to_llm_content(
msg.content,
allow_images=False,
)
if not user_content:
continue
if is_shared:
author_name = (
msg.author.display_name if msg.author else None
) or "A team member"
text_content = f"**[{author_name}]:** {text_content}"
langchain_messages.append(HumanMessage(content=text_content))
if isinstance(user_content, str):
user_content = f"**[{author_name}]:** {user_content}"
elif user_content and user_content[0].get("type") == "text":
user_content[0] = {
**user_content[0],
"text": f"**[{author_name}]:** {user_content[0].get('text', '')}",
}
langchain_messages.append(HumanMessage(content=user_content))
elif msg.role == "assistant":
langchain_messages.append(AIMessage(content=text_content))
assistant_text = assistant_content_to_llm_text(msg.content)
if assistant_text:
langchain_messages.append(AIMessage(content=assistant_text))
return langchain_messages

View file

@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402
_OPENROUTER_DYNAMIC_MARKER,
OpenRouterIntegrationService,
)
from app.services.provider_api_base import resolve_api_base # noqa: E402
from app.services.provider_capabilities import ( # noqa: E402
derive_supports_image_input,
is_known_text_only_chat_model,
@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
cap = derive_supports_image_input(
provider=cfg.get("provider"),
provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
block = is_known_text_only_chat_model(
provider=cfg.get("provider"),
provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
def _build_chat_model_string(cfg: dict) -> str:
if cfg.get("custom_provider"):
return f"{cfg['custom_provider']}/{cfg['model_name']}"
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
prefix = cfg.get("litellm_provider") or "openai"
return f"{prefix}/{cfg['model_name']}"
@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str:
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
model_string = _build_chat_model_string(cfg)
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=model_string.split("/", 1)[0],
config_api_base=cfg.get("api_base") or None,
)
kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
"max_tokens": 16,
"timeout": 60,
}
if api_base:
kwargs["api_base"] = api_base
if cfg.get("api_base"):
kwargs["api_base"] = cfg["api_base"]
if cfg.get("litellm_params"):
# Strip pricing keys — they're tracking-only and confuse some
# provider validators (e.g. azure/openai reject unknown kwargs
@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = (
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
"""Generate one tiny image to verify the deployment is reachable."""
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
if cfg.get("custom_provider"):
prefix = cfg["custom_provider"]
else:
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
prefix = cfg.get("litellm_provider") or "openai"
model_string = f"{prefix}/{cfg['model_name']}"
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=prefix,
config_api_base=cfg.get("api_base") or None,
)
base_kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
"size": "1024x1024",
"timeout": 120,
}
if api_base:
base_kwargs["api_base"] = api_base
if cfg.get("api_base"):
base_kwargs["api_base"] = cfg["api_base"]
if cfg.get("api_version"):
base_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@ -349,31 +330,6 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None:
report.add(result)
async def probe_vision_configs(report: Report, *, live: bool) -> None:
print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
if _is_or_dynamic(cfg):
continue
result = ProbeResult(
label=str(cfg.get("name") or cfg.get("model_name")),
surface="vision",
config_id=cfg.get("id"),
)
# For vision configs, capability is implied — they're in the
# dedicated vision pool. Run the same resolver to flag any
# surprise disagreement.
cap_ok, cap_note = _probe_chat_capability(cfg)
result.capability_ok = cap_ok
result.capability_note = cap_note
if live:
t0 = time.perf_counter()
ok, note = await _live_chat_image_call(cfg)
result.live_ok = ok
result.live_note = note
result.duration_s = time.perf_counter() - t0
report.add(result)
async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
print(
"\n[image generation configs from global_image_generation_configs (YAML-static)]"
@ -399,7 +355,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
"""Sample one chat (vision-capable), one vision, one image-gen model
"""Sample chat/vision-capable and image-gen models
from the live OpenRouter catalogue. Doesn't iterate the full pool
(would be hundreds of probes); just validates the integration end-
to-end on a representative model from each surface."""
@ -424,9 +380,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
for c in config.GLOBAL_LLM_CONFIGS
if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
]
or_vision = [
c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
]
or_image_gen = [
c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
]
@ -446,11 +399,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
("or-chat", _pick_first(or_chat, "anthropic/claude")),
("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
]
vision_picks = [
("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
("or-vision", _pick_first(or_vision, "anthropic/claude")),
("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
]
image_picks = [
("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
# OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
@ -460,11 +408,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
]
print(
f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} "
f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
)
for surface, picked in chat_picks + vision_picks + image_picks:
for surface, picked in chat_picks + image_picks:
if not picked:
report.add(
ProbeResult(
@ -505,7 +453,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
async def main(args: argparse.Namespace) -> int:
print("Loaded global configs:")
print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
@ -526,8 +473,6 @@ async def main(args: argparse.Namespace) -> int:
report = Report()
if not args.skip_chat:
await probe_chat_configs(report, live=args.live)
if not args.skip_vision:
await probe_vision_configs(report, live=args.live)
if not args.skip_image_gen:
await probe_image_gen_configs(report, live=args.live)
if not args.skip_openrouter:
@ -547,7 +492,6 @@ def _parse_args() -> argparse.Namespace:
)
parser.set_defaults(live=True)
parser.add_argument("--skip-chat", action="store_true")
parser.add_argument("--skip-vision", action="store_true")
parser.add_argument("--skip-image-gen", action="store_true")
parser.add_argument("--skip-openrouter", action="store_true")
return parser.parse_args()

View file

@ -19,7 +19,7 @@
# so the resolved auto-pin id is never sent to a real LLM provider.
# The values below only need to pass
# auto_model_pin_service._is_usable_global_config()
# which requires id / model_name / provider / api_key all truthy.
# which requires id / model_name / litellm_provider / api_key all truthy.
#
# Why TWO entries (premium + free):
# auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits
@ -44,9 +44,10 @@ global_llm_configs:
anonymous_enabled: false
seo_enabled: false
quality_score: 1.0
provider: "OPENAI"
litellm_provider: "openai"
model_name: "fake-e2e-model-premium"
api_key: "fake-e2e-api-key-not-for-production"
api_base: "https://api.openai.com/v1"
supports_image_input: false
quota_reserve_tokens: 1024
rpm: 1000
@ -60,9 +61,10 @@ global_llm_configs:
anonymous_enabled: false
seo_enabled: false
quality_score: 1.0
provider: "OPENAI"
litellm_provider: "openai"
model_name: "fake-e2e-model-free"
api_key: "fake-e2e-api-key-not-for-production"
api_base: "https://api.openai.com/v1"
supports_image_input: false
quota_reserve_tokens: 1024
rpm: 1000

View file

@ -0,0 +1,39 @@
"""Regression tests for model-boundary message sanitization."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from app.agents.chat.runtime.llm_config import _sanitize_messages
pytestmark = pytest.mark.unit
def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None:
original = AIMessage(
content=[
{"type": "thinking", "thinking": "private reasoning"},
{"type": "text", "text": "visible answer"},
]
)
sanitized = _sanitize_messages([original])
assert sanitized[0].content == "visible answer"
assert original.content == [
{"type": "thinking", "thinking": "private reasoning"},
{"type": "text", "text": "visible answer"},
]
def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None:
message = AIMessage(
content="",
tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}],
)
sanitized = _sanitize_messages([message])
assert sanitized[0].content is None
assert message.content == ""

View file

@ -1,6 +1,6 @@
"""Lock the runtime model-policy backstop in ``build_dependencies``.
Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so
Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so
runs are insulated from later chat/search-space model changes), and the model
policy is re-checked at run time so a captured model that is no longer billable
fails the run clearly. When no snapshot is present, resolution falls back to the
@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch):
return None
async def test_build_dependencies_resolves_captured_agent_llm_id(
async def test_build_dependencies_resolves_captured_chat_model_id(
monkeypatch: pytest.MonkeyPatch, patched_side_effects
) -> None:
"""The bundle loads with the *captured* ``agent_llm_id``, not the live search space."""
"""The bundle loads with the *captured* ``chat_model_id``, not the live search space."""
captured: dict[str, Any] = {}
async def _fake_load(_session, *, config_id, search_space_id):
@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_agent_llm_id(
lambda _ss: pytest.fail("search-space policy should not run on captured path"),
)
search_space = SimpleNamespace(agent_llm_id=-99)
search_space = SimpleNamespace(chat_model_id=-99)
result = await build_dependencies(
session=_FakeSession(search_space),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
chat_model_id=-7,
image_gen_model_id=5,
vision_model_id=-1,
)
assert captured == {"config_id": -7, "search_space_id": 42}
@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids(
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
await build_dependencies(
session=_FakeSession(SimpleNamespace(agent_llm_id=0)),
session=_FakeSession(SimpleNamespace(chat_model_id=0)),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
chat_model_id=-7,
image_gen_model_id=5,
vision_model_id=-1,
)
assert seen == {
"agent_llm_id": -7,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
"chat_model_id": -7,
"image_gen_model_id": 5,
"vision_model_id": -1,
}
@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
def _raise(**_kw):
raise AutomationModelPolicyError(
[{"kind": "image", "config_id": -2, "reason": "free model"}]
[{"kind": "image", "model_id": -2, "reason": "free model"}]
)
monkeypatch.setattr(deps_mod, "assert_models_billable", _raise)
@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
with pytest.raises(DependencyError):
await build_dependencies(
session=_FakeSession(SimpleNamespace(agent_llm_id=-7)),
session=_FakeSession(SimpleNamespace(chat_model_id=-7)),
search_space_id=42,
agent_llm_id=-7,
image_generation_config_id=-2,
vision_llm_config_id=-1,
chat_model_id=-7,
image_gen_model_id=-2,
vision_model_id=-1,
)
@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space(
lambda **_kw: pytest.fail("captured policy should not run on fallback path"),
)
search_space = SimpleNamespace(agent_llm_id=-7)
search_space = SimpleNamespace(chat_model_id=-7)
result = await build_dependencies(
session=_FakeSession(search_space), search_space_id=42
)

View file

@ -28,9 +28,9 @@ def _run() -> SimpleNamespace:
def test_build_action_ctx_propagates_captured_models() -> None:
"""``definition.models`` flows onto the ActionContext model fields."""
models = AutomationModels(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
chat_model_id=-1,
image_gen_model_id=5,
vision_model_id=-1,
)
ctx = _build_action_ctx(
cast(AsyncSession, None),
@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None:
)
assert ctx.search_space_id == 42
assert ctx.agent_llm_id == -1
assert ctx.image_generation_config_id == 5
assert ctx.vision_llm_config_id == -1
assert ctx.chat_model_id == -1
assert ctx.image_gen_model_id == 5
assert ctx.vision_model_id == -1
def test_build_action_ctx_none_models_leaves_fields_none() -> None:
@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None:
None,
)
assert ctx.agent_llm_id is None
assert ctx.image_generation_config_id is None
assert ctx.vision_llm_config_id is None
assert ctx.chat_model_id is None
assert ctx.image_gen_model_id is None
assert ctx.vision_model_id is None

View file

@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None:
name="Daily digest",
plan=[PlanStep(step_id="s1", action="agent_task")],
models=AutomationModels(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
chat_model_id=-1,
image_gen_model_id=5,
vision_model_id=-1,
),
)
dumped = definition.model_dump(mode="json", by_alias=True)
assert dumped["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
"chat_model_id": -1,
"image_gen_model_id": 5,
"vision_model_id": -1,
}
restored = AutomationDefinition.model_validate(dumped)
assert restored.models is not None
assert restored.models.agent_llm_id == -1
assert restored.models.image_generation_config_id == 5
assert restored.models.vision_llm_config_id == -1
assert restored.models.chat_model_id == -1
assert restored.models.image_gen_model_id == 5
assert restored.models.vision_model_id == -1
def test_automation_definition_rejects_unknown_top_level_field() -> None:

View file

@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation(
def _raise(_ss):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
[{"kind": "llm", "model_id": 0, "reason": "Auto mode"}]
)
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
service = _service(SimpleNamespace(agent_llm_id=0))
service = _service(SimpleNamespace(chat_model_id=0))
with pytest.raises(HTTPException) as exc_info:
await service._assert_models_billable(1)
@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok(
automation_mod, "assert_automation_models_billable", lambda _ss: None
)
search_space = SimpleNamespace(agent_llm_id=-1)
search_space = SimpleNamespace(chat_model_id=-1)
service = _service(search_space)
assert await service._assert_models_billable(1) is search_space
@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space(
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=-1,
image_generation_config_id=5,
vision_llm_config_id=-1,
chat_model_id=-1,
image_gen_model_id=5,
vision_model_id=-1,
)
service = _service(search_space)
payload = AutomationCreate(
@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space(
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
"chat_model_id": -1,
"image_gen_model_id": 5,
"vision_model_id": -1,
}
@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
search_space = SimpleNamespace(
agent_llm_id=None,
image_generation_config_id=None,
vision_llm_config_id=None,
chat_model_id=None,
image_gen_model_id=None,
vision_model_id=None,
)
service = _service(search_space)
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
automation = await service.create(payload)
assert automation.definition["models"] == {
"agent_llm_id": 0,
"image_generation_config_id": 0,
"vision_llm_config_id": 0,
"chat_model_id": 0,
"image_gen_model_id": 0,
"vision_model_id": 0,
}
@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided(
)
validated: dict[str, Any] = {}
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
validated["ids"] = (
agent_llm_id,
image_generation_config_id,
vision_llm_config_id,
chat_model_id,
image_gen_model_id,
vision_model_id,
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided(
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
service = _service(SimpleNamespace(agent_llm_id=-99))
service = _service(SimpleNamespace(chat_model_id=-99))
payload = AutomationCreate(
search_space_id=1,
name="A",
definition=_definition(
models=AutomationModels(
agent_llm_id=-1,
image_generation_config_id=7,
vision_llm_config_id=-2,
chat_model_id=-1,
image_gen_model_id=7,
vision_model_id=-2,
)
),
)
@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided(
assert validated["ids"] == (-1, 7, -2)
assert automation.definition["models"] == {
"agent_llm_id": -1,
"image_generation_config_id": 7,
"vision_llm_config_id": -2,
"chat_model_id": -1,
"image_gen_model_id": 7,
"vision_model_id": -2,
}
@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models(
) -> None:
"""A non-billable explicit selection maps the policy error to HTTP 422."""
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": -3, "reason": "free model"}]
[{"kind": "llm", "model_id": -3, "reason": "free model"}]
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models(
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
service = _service(SimpleNamespace(agent_llm_id=-3))
service = _service(SimpleNamespace(chat_model_id=-3))
payload = AutomationCreate(
search_space_id=1,
name="A",
definition=_definition(
models=AutomationModels(
agent_llm_id=-3,
image_generation_config_id=7,
vision_llm_config_id=-2,
chat_model_id=-3,
image_gen_model_id=7,
vision_model_id=-2,
)
),
)
@ -277,9 +277,9 @@ async def test_update_preserves_captured_models(
) -> None:
"""A definition edit carries over the previously captured ``models``."""
captured = {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
"chat_model_id": -1,
"image_gen_model_id": 5,
"vision_model_id": -1,
}
existing = SimpleNamespace(
search_space_id=1,
@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid(
"name": "A",
"plan": [],
"models": {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
"chat_model_id": -1,
"image_gen_model_id": 5,
"vision_model_id": -1,
},
},
version=3,
)
validated: dict[str, Any] = {}
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
validated["ids"] = (
agent_llm_id,
image_generation_config_id,
vision_llm_config_id,
chat_model_id,
image_gen_model_id,
vision_model_id,
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid(
patch = AutomationUpdate(
definition=_definition(
models=AutomationModels(
agent_llm_id=-2,
image_generation_config_id=9,
vision_llm_config_id=-2,
chat_model_id=-2,
image_gen_model_id=9,
vision_model_id=-2,
)
)
)
@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid(
assert validated["ids"] == (-2, 9, -2)
assert result.definition["models"] == {
"agent_llm_id": -2,
"image_generation_config_id": 9,
"vision_llm_config_id": -2,
"chat_model_id": -2,
"image_gen_model_id": 9,
"vision_model_id": -2,
}
assert result.version == 4
@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models(
"name": "A",
"plan": [],
"models": {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
"chat_model_id": -1,
"image_gen_model_id": 5,
"vision_model_id": -1,
},
},
version=3,
)
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
raise AutomationModelPolicyError(
[{"kind": "llm", "config_id": -7, "reason": "free model"}]
[{"kind": "llm", "model_id": -7, "reason": "free model"}]
)
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models(
patch = AutomationUpdate(
definition=_definition(
models=AutomationModels(
agent_llm_id=-7,
image_generation_config_id=5,
vision_llm_config_id=-1,
chat_model_id=-7,
image_gen_model_id=5,
vision_model_id=-1,
)
)
)
@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation(
premium without an unrelated edit tripping the policy check.
"""
captured = {
"agent_llm_id": -1,
"image_generation_config_id": 5,
"vision_llm_config_id": -1,
"chat_model_id": -1,
"image_gen_model_id": 5,
"vision_model_id": -1,
}
existing = SimpleNamespace(
search_space_id=1,
@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload(
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
)
service = _service(SimpleNamespace(agent_llm_id=-2))
service = _service(SimpleNamespace(chat_model_id=-2))
result = await service.model_eligibility(search_space_id=5)
assert result == {"allowed": False, "violations": [{"kind": "image"}]}

View file

@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
return SimpleNamespace(
agent_llm_id=llm,
image_generation_config_id=image,
vision_llm_config_id=vision,
chat_model_id=llm,
image_gen_model_id=image,
vision_model_id=vision,
)
@ -39,29 +39,11 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
"""
llm_configs = {
-1: {"id": -1, "billing_tier": "premium"},
-2: {"id": -2, "billing_tier": "free"},
}
monkeypatch.setattr(
"app.agents.chat.runtime.llm_config.load_global_llm_config_by_id",
lambda cid: llm_configs.get(cid),
)
from app.config import config as app_config
monkeypatch.setattr(
app_config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
],
raising=False,
)
monkeypatch.setattr(
app_config,
"GLOBAL_VISION_LLM_CONFIGS",
"GLOBAL_MODELS",
[
{"id": -1, "billing_tier": "premium"},
{"id": -2, "billing_tier": "free"},
@ -71,7 +53,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
return None
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
"""A positive config id is a user-owned BYOK model — always billable."""
allowed, reason = model_policy._classify(kind, 7)
@ -79,7 +61,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
@pytest.mark.parametrize("config_id", [0, None])
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
"""Auto mode (id 0) and an unset slot (None) are blocked."""
@ -88,7 +70,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
assert "Auto mode" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
"""A negative (global) id with premium billing tier is allowed."""
allowed, reason = model_policy._classify(kind, -1)
@ -96,7 +78,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
assert reason == ""
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
"""A negative (global) id with a free billing tier is blocked."""
allowed, reason = model_policy._classify(kind, -2)
@ -104,7 +86,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None:
assert "free model" in reason
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
"""A negative id that resolves to no config is treated as not premium."""
allowed, _ = model_policy._classify(kind, -999)
@ -125,10 +107,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None:
assert result["allowed"] is False
kinds = {v["kind"] for v in result["violations"]}
assert kinds == {"llm", "image", "vision"}
# config_id is echoed back for the UI / settings deep-link.
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
assert kinds == {"chat", "image", "vision"}
# model_id is echoed back for the UI / settings deep-link.
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
def test_assert_raises_with_violations(patched_globals) -> None:
@ -138,7 +120,7 @@ def test_assert_raises_with_violations(patched_globals) -> None:
assert_automation_models_billable(search_space)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
assert exc_info.value.violations[0]["kind"] == "chat"
def test_assert_passes_when_all_billable(patched_globals) -> None:
@ -153,7 +135,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None:
def test_get_model_eligibility_all_billable(patched_globals) -> None:
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
result = get_model_eligibility(
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1
)
assert result == {"allowed": True, "violations": []}
@ -161,28 +143,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None:
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
result = get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
)
assert result["allowed"] is False
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
def test_assert_models_billable_raises(patched_globals) -> None:
"""``assert_models_billable`` raises when any explicit id is blocked."""
with pytest.raises(AutomationModelPolicyError) as exc_info:
assert_models_billable(
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
chat_model_id=0, image_gen_model_id=5, vision_model_id=-1
)
assert len(exc_info.value.violations) == 1
assert exc_info.value.violations[0]["kind"] == "llm"
assert exc_info.value.violations[0]["kind"] == "chat"
def test_assert_models_billable_passes(patched_globals) -> None:
"""No exception when every explicit id is premium or BYOK."""
assert (
assert_models_billable(
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
chat_model_id=3, image_gen_model_id=-1, vision_model_id=4
)
is None
)
@ -192,5 +174,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
"""The search-space wrapper produces the same result as the ID core."""
search_space = _search_space(llm=-2, image=0, vision=-2)
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
)

View file

@ -1,110 +0,0 @@
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
There is no DB column for ``supports_image_input`` on
``NewLLMConfig`` the value is resolved at the API boundary by
``derive_supports_image_input`` so the new-chat selector / streaming
task can read the same field shape regardless of source (BYOK vs YAML
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
user out of their own model choice.
"""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from uuid import uuid4
import pytest
from app.db import LiteLLMProvider
from app.routes import new_llm_config_routes
pytestmark = pytest.mark.unit
def _byok_row(
*,
id_: int,
model_name: str,
base_model: str | None = None,
provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
custom_provider: str | None = None,
) -> object:
"""Mimic the SQLAlchemy row's attribute surface; ``model_validate``
walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
enum validator accepts it same as the ORM row would carry."""
return SimpleNamespace(
id=id_,
name=f"BYOK-{id_}",
description=None,
provider=provider,
custom_provider=custom_provider,
model_name=model_name,
api_key="sk-byok",
api_base=None,
litellm_params={"base_model": base_model} if base_model else None,
system_instructions="",
use_default_system_instructions=True,
citations_enabled=True,
created_at=datetime.now(tz=UTC),
search_space_id=42,
user_id=uuid4(),
)
def test_serialize_byok_known_vision_model_resolves_true():
"""The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
True. The serialized row carries that value through to the
``NewLLMConfigRead`` schema."""
row = _byok_row(id_=1, model_name="gpt-4o")
serialized = new_llm_config_routes._serialize_byok_config(row)
assert serialized.supports_image_input is True
assert serialized.id == 1
assert serialized.model_name == "gpt-4o"
def test_serialize_byok_unknown_model_default_allows():
"""Unknown / unmapped: default-allow. The streaming-task safety net
is the actual block, and it requires LiteLLM to *explicitly* say
text-only so a brand new BYOK model should not be pre-judged."""
row = _byok_row(
id_=2,
model_name="brand-new-model-x9-unmapped",
provider=LiteLLMProvider.CUSTOM,
custom_provider="brand_new_proxy",
)
serialized = new_llm_config_routes._serialize_byok_config(row)
assert serialized.supports_image_input is True
def test_serialize_byok_uses_base_model_when_present():
"""Azure-style: ``model_name`` is the deployment id, ``base_model``
inside ``litellm_params`` is the canonical sku LiteLLM knows. The
helper must consult ``base_model`` first or unrecognised deployment
ids would shadow the real capability."""
row = _byok_row(
id_=3,
model_name="my-azure-deployment-id-no-litellm-knows-this",
base_model="gpt-4o",
provider=LiteLLMProvider.AZURE_OPENAI,
)
serialized = new_llm_config_routes._serialize_byok_config(row)
assert serialized.supports_image_input is True
def test_serialize_byok_returns_pydantic_read_model():
"""The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
the schema additions are guaranteed to be present in the API
surface. This guards against a future regression where someone
deletes the augmentation step and falls back to ORM passthrough."""
from app.schemas import NewLLMConfigRead
row = _byok_row(id_=4, model_name="gpt-4o")
serialized = new_llm_config_routes._serialize_byok_config(row)
assert isinstance(serialized, NewLLMConfigRead)

View file

@ -1,184 +0,0 @@
"""Unit tests for ``is_premium`` derivation on the global image-gen and
vision-LLM list endpoints.
Chat globals (``GET /global-llm-configs``) already emit
``is_premium = (billing_tier == "premium")``. Image and vision did not,
which made the new-chat ``model-selector`` render the Free/Premium badge
on the Chat tab but skip it on the Image and Vision tabs (the selector
keys its badge logic off ``is_premium``). These tests pin parity:
* YAML free entry ``is_premium=False``
* YAML premium entry ``is_premium=True``
* OpenRouter dynamic premium entry ``is_premium=True``
* Auto stub (always emitted when at least one config is present)
``is_premium=False``
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
_IMAGE_FIXTURE: list[dict] = [
{
"id": -1,
"name": "DALL-E 3",
"provider": "OPENAI",
"model_name": "dall-e-3",
"api_key": "sk-test",
"billing_tier": "free",
},
{
"id": -2,
"name": "GPT-Image 1 (premium)",
"provider": "OPENAI",
"model_name": "gpt-image-1",
"api_key": "sk-test",
"billing_tier": "premium",
},
{
"id": -20_001,
"name": "google/gemini-2.5-flash-image (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash-image",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "premium",
},
]
_VISION_FIXTURE: list[dict] = [
{
"id": -1,
"name": "GPT-4o Vision",
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
},
{
"id": -2,
"name": "Claude 3.5 Sonnet (premium)",
"provider": "ANTHROPIC",
"model_name": "claude-3-5-sonnet",
"api_key": "sk-ant-test",
"billing_tier": "premium",
},
{
"id": -30_001,
"name": "openai/gpt-4o (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "premium",
},
]
# =============================================================================
# Image generation
# =============================================================================
@pytest.mark.asyncio
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
"""Each emitted config must carry ``is_premium`` derived server-side
from ``billing_tier``. The Auto stub is always free.
"""
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(
config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
)
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
by_id = {c["id"]: c for c in payload}
# Auto stub is always emitted when at least one global config exists,
# and it must always declare itself free (Auto-mode billing-tier
# surfacing is a separate follow-up).
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
assert by_id[0]["is_premium"] is False
assert by_id[0]["billing_tier"] == "free"
# YAML free entry — ``is_premium=False``
assert by_id[-1]["is_premium"] is False
assert by_id[-1]["billing_tier"] == "free"
# YAML premium entry — ``is_premium=True``
assert by_id[-2]["is_premium"] is True
assert by_id[-2]["billing_tier"] == "premium"
# OpenRouter dynamic premium entry — same field, same derivation
assert by_id[-20_001]["is_premium"] is True
assert by_id[-20_001]["billing_tier"] == "premium"
# Every emitted dict (including Auto) must have the field — never missing.
for cfg in payload:
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
"""When there are no global configs at all, the endpoint emits an
empty list (no Auto stub) Auto mode would have nothing to route to.
"""
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
assert payload == []
# =============================================================================
# Vision LLM
# =============================================================================
@pytest.mark.asyncio
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
from app.config import config
from app.routes import vision_llm_routes
monkeypatch.setattr(
config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
)
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
by_id = {c["id"]: c for c in payload}
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
assert by_id[0]["is_premium"] is False
assert by_id[0]["billing_tier"] == "free"
assert by_id[-1]["is_premium"] is False
assert by_id[-1]["billing_tier"] == "free"
assert by_id[-2]["is_premium"] is True
assert by_id[-2]["billing_tier"] == "premium"
assert by_id[-30_001]["is_premium"] is True
assert by_id[-30_001]["billing_tier"] == "premium"
for cfg in payload:
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
from app.config import config
from app.routes import vision_llm_routes
monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
assert payload == []

View file

@ -1,106 +0,0 @@
"""Unit tests for ``supports_image_input`` derivation on the chat global
config endpoint (``GET /global-new-llm-configs``).
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
loader for operator overrides, or by the OpenRouter integration from
``architecture.input_modalities``) wins.
2. ``derive_supports_image_input`` helper default-allow on unknown
models, only False when LiteLLM / OR modalities are definitive.
The flag is purely informational at the API boundary. The streaming
task safety net (``is_known_text_only_chat_model``) is the actual block,
and it requires LiteLLM to *explicitly* mark the model as text-only.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
_FIXTURE: list[dict] = [
{
"id": -1,
"name": "GPT-4o (explicit true)",
"description": "vision-capable, explicit YAML override",
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
"supports_image_input": True,
},
{
"id": -2,
"name": "DeepSeek V3 (explicit false)",
"description": "OpenRouter dynamic — modality-derived false",
"provider": "OPENROUTER",
"model_name": "deepseek/deepseek-v3.2-exp",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "free",
"supports_image_input": False,
},
{
"id": -10_010,
"name": "Unannotated GPT-4o",
"description": "no flag set — resolver should derive True via LiteLLM",
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
# supports_image_input intentionally absent
},
{
"id": -10_011,
"name": "Unannotated unknown model",
"description": "unmapped — default-allow True",
"provider": "CUSTOM",
"custom_provider": "brand_new_proxy",
"model_name": "brand-new-model-x9",
"api_key": "sk-test",
"billing_tier": "free",
},
]
@pytest.mark.asyncio
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
"""Each emitted chat config carries ``supports_image_input`` as a
bool. Explicit values win; unannotated entries are resolved via the
helper (default-allow True)."""
from app.config import config
from app.routes import new_llm_config_routes
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
by_id = {c["id"]: c for c in payload}
# Auto stub: optimistic True so the user can keep Auto selected with
# vision-capable deployments somewhere in the pool.
assert 0 in by_id, "Auto stub should be emitted when configs exist"
assert by_id[0]["supports_image_input"] is True
assert by_id[0]["is_auto_mode"] is True
# Explicit True is preserved.
assert by_id[-1]["supports_image_input"] is True
# Explicit False is preserved (the exact failure mode the safety net
# guards against — DeepSeek V3 over OpenRouter would 404 with "No
# endpoints found that support image input").
assert by_id[-2]["supports_image_input"] is False
# Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
assert by_id[-10_010]["supports_image_input"] is True
# Unknown / unmapped model: default-allow rather than pre-judge.
assert by_id[-10_011]["supports_image_input"] is True
for cfg in payload:
assert "supports_image_input" in cfg, (
f"supports_image_input missing from {cfg.get('id')}"
)
assert isinstance(cfg["supports_image_input"], bool)

View file

@ -27,9 +27,18 @@ async def test_resolve_billing_for_auto_mode(monkeypatch):
from app.routes import image_generation_routes
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
search_space = SimpleNamespace(image_generation_config_id=None)
async def _no_auto_candidates(*_args, **_kwargs):
return []
monkeypatch.setattr(
image_generation_routes,
"auto_model_candidates",
_no_auto_candidates,
)
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, # Not consumed on this code path.
session=None,
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
search_space=search_space,
)
@ -45,26 +54,48 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch):
monkeypatch.setattr(
config,
"GLOBAL_IMAGE_GEN_CONFIGS",
"GLOBAL_MODELS",
[
{
"id": -1,
"provider": "OPENAI",
"model_name": "gpt-image-1",
"connection_id": -101,
"model_id": "gpt-image-1",
"billing_tier": "premium",
"quota_reserve_micros": 75_000,
"catalog": {"quota_reserve_micros": 75_000},
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash-image",
"connection_id": -102,
"model_id": "google/gemini-2.5-flash-image",
"billing_tier": "free",
"catalog": {},
},
],
raising=False,
)
monkeypatch.setattr(
config,
"GLOBAL_CONNECTIONS",
[
{
"id": -101,
"provider": "openai",
"api_key": "sk-test",
"base_url": None,
"extra": {},
},
{
"id": -102,
"provider": "openrouter",
"api_key": "sk-or-test",
"base_url": "https://openrouter.ai/api/v1",
"extra": {},
},
],
raising=False,
)
search_space = SimpleNamespace(image_generation_config_id=None)
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
# Premium with override.
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
@ -94,7 +125,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free():
from app.routes import image_generation_routes
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
search_space = SimpleNamespace(image_generation_config_id=None)
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=42, search_space=search_space
)
@ -105,7 +136,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free():
@pytest.mark.asyncio
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
"""When the request omits ``image_generation_config_id``, the helper
"""When the request omits ``image_gen_model_id``, the helper
must consult the search space's default — so a search space pinned
to a premium global config still gates new requests by quota.
"""
@ -114,19 +145,34 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
monkeypatch.setattr(
config,
"GLOBAL_IMAGE_GEN_CONFIGS",
"GLOBAL_MODELS",
[
{
"id": -7,
"provider": "OPENAI",
"model_name": "gpt-image-1",
"connection_id": -101,
"model_id": "gpt-image-1",
"billing_tier": "premium",
"catalog": {},
}
],
raising=False,
)
monkeypatch.setattr(
config,
"GLOBAL_CONNECTIONS",
[
{
"id": -101,
"provider": "openai",
"api_key": "sk-test",
"base_url": None,
"extra": {},
}
],
raising=False,
)
search_space = SimpleNamespace(image_generation_config_id=-7)
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7)
(
tier,
model,

View file

@ -1,27 +1,4 @@
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
Validates the resolver used by Celery podcast/video tasks to compute
``(owner_user_id, billing_tier, base_model)`` from a search space and its
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
``stream_new_chat.py:2294-2351`` and is the single integration point that
prevents Auto-mode podcast/video from leaking premium credit.
Coverage:
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
global returns ``("premium", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
global returns ``("free", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
always ``"free"``.
* Auto mode + ``thread_id=None`` fallback to ``("free", "auto")`` without
hitting the pin service.
* Negative id (no Auto) uses ``get_global_llm_config``'s
``billing_tier``.
* Positive id (user BYOK) always ``"free"``.
* Search space not found raises ``ValueError``.
* ``agent_llm_id`` is None raises ``ValueError``.
"""
"""Unit tests for ``_resolve_agent_billing_for_search_space``."""
from __future__ import annotations
@ -34,11 +11,6 @@ import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
@ -51,14 +23,6 @@ class _FakeExecResult:
class _FakeSession:
"""Tiny AsyncSession stub.
``responses`` is a list of objects to return from successive
``execute()`` calls (in order). The resolver makes at most two
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
lookup), so two queued responses cover the matrix.
"""
def __init__(self, responses: list):
self._responses = list(responses)
@ -67,9 +31,6 @@ class _FakeSession:
return _FakeExecResult(None)
return _FakeExecResult(self._responses.pop(0))
async def commit(self) -> None:
pass
@dataclass
class _FakePinResolution:
@ -78,53 +39,33 @@ class _FakePinResolution:
from_existing_pin: bool = False
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
return SimpleNamespace(
id=42,
agent_llm_id=agent_llm_id,
user_id=user_id,
)
def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace:
return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id)
def _make_byok_config(
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
def _make_byok_model(
*, id_: int, base_model: str | None = None, model_id: str = "gpt-byok"
) -> SimpleNamespace:
return SimpleNamespace(
id=id_,
model_name=model_name,
litellm_params={"base_model": base_model} if base_model else {},
model_id=model_id,
catalog={"base_model": base_model} if base_model else {},
connection=SimpleNamespace(enabled=True, search_space_id=42, user_id=None),
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
"""Auto + thread → pin service resolves to negative-id premium config →
resolver returns ``("premium", <base_model>)``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
# Mock the pin service to return a concrete premium config id.
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
assert selected_llm_config_id == 0
assert thread_id == 99
async def _fake_resolve_pin(*_args, **kwargs):
assert kwargs["selected_llm_config_id"] == 0
assert kwargs["thread_id"] == 99
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
# Mock global config lookup to return a premium entry.
def _fake_get_global(cfg_id):
if cfg_id == -1:
return {
@ -135,8 +76,6 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
}
return None
# Lazy imports inside the resolver — patch the *target* modules so the
# imported names resolve to our fakes.
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
@ -154,77 +93,18 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
"""Auto + thread → pin returns negative-id free config → resolver
returns ``("free", <base_model>)``. Same path the pin service takes for
out-of-credit users (graceful degradation)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
def _fake_get_global(cfg_id):
if cfg_id == -3:
return {
"id": -3,
"model_name": "openrouter/free-model",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/free-model"},
}
return None
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/free-model"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
"""Auto + thread → pin returns positive-id BYOK config → resolver
returns ``("free", ...)`` (BYOK is always free per
``AgentConfig.from_new_llm_config``)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
byok_cfg = _make_byok_config(
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
search_space = _make_search_space(chat_model_id=0, user_id=user_id)
byok_model = _make_byok_model(
id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude"
)
session = _FakeSession([search_space, byok_cfg])
session = _FakeSession([search_space, byok_model])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
async def _fake_resolve_pin(*_args, **_kwargs):
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
import app.services.auto_model_pin_service as pin_module
@ -244,13 +124,10 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
the pin service. Forward-compat fallback for any future direct-API
entrypoint that doesn't have a chat thread."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
@ -263,13 +140,10 @@ async def test_auto_mode_without_thread_id_falls_back_to_free():
@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
"""If the pin service raises ``ValueError`` (thread missing /
mismatched search space), the resolver should log and return free
rather than killing the whole task."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
async def _fake_resolve_pin(*args, **kwargs):
raise ValueError("thread missing")
@ -291,12 +165,10 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
return its ``billing_tier``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
@ -319,50 +191,15 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch):
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "openrouter/some-free",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/some-free"},
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/some-free"
@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
"""When the global config has no ``litellm_params.base_model``, the
resolver falls back to ``model_name`` matching chat's behavior."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "fallback-model",
"billing_tier": "premium",
# No litellm_params.
}
return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"}
import app.services.llm_service as llm_module
@ -378,14 +215,12 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat
@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
regardless of underlying provider tier."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
session = _FakeSession([search_space, byok_cfg])
search_space = _make_search_space(chat_model_id=23, user_id=user_id)
byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet")
session = _FakeSession([search_space, byok_model])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
@ -398,13 +233,10 @@ async def test_positive_id_byok_is_always_free():
@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
"""If the BYOK config row is missing/deleted but the search space still
points at it, the resolver still returns free (no debit) with an empty
base_model billable_call's premium path is skipped, no harm done."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
@ -419,18 +251,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
async def test_search_space_not_found_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
session = _FakeSession([None])
with pytest.raises(ValueError, match="Search space"):
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
await _resolve_agent_billing_for_search_space(
_FakeSession([None]), search_space_id=999
)
@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
async def test_chat_model_id_none_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)])
with pytest.raises(ValueError, match="agent_llm_id"):
with pytest.raises(ValueError, match="chat_model_id"):
await _resolve_agent_billing_for_search_space(session, search_space_id=42)

View file

@ -17,8 +17,39 @@ from app.services.auto_model_pin_service import (
pytestmark = pytest.mark.unit
class _FakeRedis:
def __init__(self):
self.values: dict[str, str] = {}
self.ttls: dict[str, int] = {}
def set(self, key: str, value: str, *, ex: int | None = None):
self.values[key] = value
if ex is not None:
self.ttls[key] = ex
return True
def mget(self, keys: list[str]):
return [self.values.get(key) for key in keys]
def delete(self, *keys: str):
removed = 0
for key in keys:
if key in self.values:
removed += 1
self.values.pop(key, None)
self.ttls.pop(key, None)
return removed
def scan_iter(self, pattern: str):
prefix = pattern.removesuffix("*")
return (key for key in list(self.values) if key.startswith(prefix))
@pytest.fixture(autouse=True)
def _clear_runtime_cooldown_map():
def _clear_runtime_cooldown_map(monkeypatch):
import app.services.auto_model_pin_service as svc
monkeypatch.setattr(svc, "_runtime_cooldown_redis", _FakeRedis())
clear_runtime_cooldown()
clear_healthy()
yield
@ -32,8 +63,9 @@ class _FakeQuotaResult:
class _FakeExecResult:
def __init__(self, thread):
def __init__(self, *, thread=None, scalars=None):
self._thread = thread
self._scalars = scalars or []
def unique(self):
return self
@ -41,19 +73,71 @@ class _FakeExecResult:
def scalar_one_or_none(self):
return self._thread
def scalars(self):
return SimpleNamespace(all=lambda: self._scalars)
class _FakeSession:
def __init__(self, thread):
def __init__(self, thread, *, models=None):
self.thread = thread
self.models = models or []
self.commit_count = 0
self.execute_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self.thread)
self.execute_count += 1
if self.execute_count == 1:
return _FakeExecResult(thread=self.thread)
return _FakeExecResult(scalars=self.models)
async def commit(self):
self.commit_count += 1
def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
"""Patch the new global model catalog shape from compact legacy cfg fixtures."""
connections = []
models = []
for cfg in configs:
config_id = int(cfg["id"])
connection_id = config_id - 100_000
provider = cfg.get("provider") or cfg.get("litellm_provider")
model_name = cfg["model_name"]
connections.append(
{
"id": connection_id,
"provider": provider,
"scope": "GLOBAL",
"enabled": True,
}
)
models.append(
{
"id": config_id,
"connection_id": connection_id,
"model_id": model_name,
"display_name": cfg.get("name") or model_name,
"supports_chat": cfg.get("supports_chat", True),
"supports_image_input": cfg.get("supports_image_input", True),
"supports_tools": cfg.get("supports_tools", True),
"supports_image_generation": cfg.get(
"supports_image_generation", False
),
"capabilities_override": cfg.get("capabilities_override") or {},
"billing_tier": cfg.get("billing_tier", "free"),
"catalog": {
"auto_pin_tier": cfg.get("auto_pin_tier"),
"quality_score": cfg.get("quality_score")
or cfg.get("quality_score_static"),
},
}
)
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs)
monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections)
monkeypatch.setattr(config, "GLOBAL_MODELS", models)
def _thread(
*,
search_space_id: int = 10,
@ -71,14 +155,19 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
{
"id": -2,
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -111,13 +200,13 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
@ -125,7 +214,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -154,17 +243,19 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
@pytest.mark.asyncio
async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_model(
monkeypatch,
):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5.1",
"api_key": "k1",
"billing_tier": "premium",
@ -173,7 +264,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5.4",
"api_key": "k2",
"billing_tier": "premium",
@ -182,12 +273,39 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
},
{
"id": -3,
"provider": "OPENROUTER",
"model_name": "openai/gpt-5.4",
"litellm_provider": "anthropic",
"model_name": "claude-opus",
"api_key": "k3",
"billing_tier": "premium",
"auto_pin_tier": "B",
"quality_score": 100,
"auto_pin_tier": "A",
"quality_score": 99,
},
{
"id": -4,
"litellm_provider": "openai",
"model_name": "gpt-5.3",
"api_key": "k4",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 98,
},
{
"id": -5,
"litellm_provider": "gemini",
"model_name": "gemini-3-pro",
"api_key": "k5",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 97,
},
{
"id": -6,
"litellm_provider": "xai",
"model_name": "grok-5",
"api_key": "k6",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 96,
},
],
)
@ -207,7 +325,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
assert result.resolved_llm_config_id in {-1, -3, -4, -5, -6}
assert result.resolved_tier == "premium"
@ -216,13 +334,13 @@ async def test_next_turn_reuses_existing_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -257,13 +375,13 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -295,20 +413,20 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -340,20 +458,20 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -385,20 +503,20 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -433,11 +551,16 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-2))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
{
"id": -2,
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
},
],
)
@ -458,11 +581,16 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-999))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
{
"id": -2,
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
},
],
)
@ -487,7 +615,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
# ---------------------------------------------------------------------------
# Quality-aware pin selection (Auto Fastest upgrade)
# Quality-aware pin selection (Auto upgrade)
# ---------------------------------------------------------------------------
@ -498,13 +626,13 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "venice/dead-model",
"api_key": "k1",
"billing_tier": "free",
@ -514,7 +642,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-flash",
"api_key": "k1",
"billing_tier": "free",
@ -550,13 +678,13 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k-yaml",
"billing_tier": "premium",
@ -566,7 +694,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "openai/gpt-5",
"api_key": "k-or",
"billing_tier": "premium",
@ -602,13 +730,13 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k-yaml",
"billing_tier": "premium",
@ -618,7 +746,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-flash:free",
"api_key": "k-or",
"billing_tier": "free",
@ -656,7 +784,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
high_score_cfgs = [
{
"id": -i,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": f"gpt-x-{i}",
"api_key": "k",
"billing_tier": "premium",
@ -668,7 +796,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
]
low_score_trap = {
"id": -99,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "tiny-legacy",
"api_key": "k",
"billing_tier": "premium",
@ -676,9 +804,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
"quality_score": 10,
"health_gated": False,
}
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[*high_score_cfgs, low_score_trap],
)
@ -723,13 +851,13 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "venice/dead-model",
"api_key": "k",
"billing_tier": "premium",
@ -739,7 +867,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k",
"billing_tier": "premium",
@ -775,13 +903,13 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k",
"billing_tier": "premium",
@ -791,7 +919,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5-pro",
"api_key": "k",
"billing_tier": "premium",
@ -833,13 +961,13 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
@ -849,7 +977,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",
@ -881,18 +1009,86 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
assert result.from_existing_pin is False
def test_mark_runtime_cooldown_writes_shared_redis(monkeypatch):
import app.services.auto_model_pin_service as svc
mark_runtime_cooldown(-9, reason="provider_rate_limited", cooldown_seconds=123)
redis_client = svc._runtime_cooldown_redis
assert redis_client.values["auto:cooldown:llm:-9"] == "provider_rate_limited"
assert redis_client.ttls["auto:cooldown:llm:-9"] == 123
@pytest.mark.asyncio
async def test_shared_runtime_cooldown_blocks_pin_across_workers(monkeypatch):
"""A Redis cooldown written by another worker should invalidate local pins."""
import app.services.auto_model_pin_service as svc
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
_set_global_llm_configs(
monkeypatch,
config,
[
{
"id": -1,
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 90,
"health_gated": False,
},
{
"id": -2,
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 80,
"health_gated": False,
},
],
)
svc._runtime_cooldown_redis.set(
"auto:cooldown:llm:-1",
"provider_rate_limited",
ex=600,
)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_blocked,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is False
@pytest.mark.asyncio
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
@ -931,13 +1127,13 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
@ -947,7 +1143,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",

View file

@ -45,8 +45,9 @@ class _FakeQuotaResult:
class _FakeExecResult:
def __init__(self, thread):
def __init__(self, *, thread=None, scalars=None):
self._thread = thread
self._scalars = scalars or []
def unique(self):
return self
@ -54,14 +55,21 @@ class _FakeExecResult:
def scalar_one_or_none(self):
return self._thread
def scalars(self):
return SimpleNamespace(all=lambda: self._scalars)
class _FakeSession:
def __init__(self, thread):
self.thread = thread
self.commit_count = 0
self.execute_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self.thread)
self.execute_count += 1
if self.execute_count == 1:
return _FakeExecResult(thread=self.thread)
return _FakeExecResult(scalars=[])
async def commit(self):
self.commit_count += 1
@ -71,10 +79,64 @@ def _thread(*, pinned: int | None = None):
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
from app.services.provider_capabilities import derive_supports_image_input
connections = []
models = []
for cfg in configs:
config_id = int(cfg["id"])
connection_id = config_id - 100_000
provider = cfg.get("provider") or cfg.get("litellm_provider")
model_name = cfg["model_name"]
if "supports_image_input" not in cfg:
litellm_params = cfg.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
connections.append(
{
"id": connection_id,
"provider": provider,
"scope": "GLOBAL",
"enabled": True,
}
)
model = {
"id": config_id,
"connection_id": connection_id,
"model_id": model_name,
"display_name": cfg.get("name") or model_name,
"supports_chat": cfg.get("supports_chat", True),
"supports_tools": cfg.get("supports_tools", True),
"supports_image_generation": cfg.get("supports_image_generation", False),
"capabilities_override": cfg.get("capabilities_override") or {},
"billing_tier": cfg.get("billing_tier", "free"),
"catalog": {
"auto_pin_tier": cfg.get("auto_pin_tier"),
"quality_score": cfg.get("quality_score"),
},
"supports_image_input": cfg["supports_image_input"],
}
models.append(model)
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs)
monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections)
monkeypatch.setattr(config, "GLOBAL_MODELS", models)
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": f"vision-{id_}",
"api_key": "k",
"billing_tier": tier,
@ -87,7 +149,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": f"text-{id_}",
"api_key": "k",
"billing_tier": tier,
@ -108,11 +170,7 @@ async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2)],
)
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,
@ -140,11 +198,7 @@ async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2)],
)
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,
@ -172,9 +226,9 @@ async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
from app.config import config
session = _FakeSession(_thread(pinned=-2))
monkeypatch.setattr(
_set_global_llm_configs(
monkeypatch,
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
)
monkeypatch.setattr(
@ -203,10 +257,8 @@ async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _text_only_cfg(-2)],
_set_global_llm_configs(
monkeypatch, config, [_text_only_cfg(-1), _text_only_cfg(-2)]
)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
@ -231,11 +283,7 @@ async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1)],
)
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1)])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,
@ -261,7 +309,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
session = _FakeSession(_thread())
cfg_unannotated_vision = {
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-4o", # known vision model in LiteLLM map
"api_key": "k",
"billing_tier": "free",
@ -269,7 +317,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
"quality_score": 80,
# NOTE: no supports_image_input key
}
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
_set_global_llm_configs(monkeypatch, config, [cfg_unannotated_vision])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_premium_allowed,

View file

@ -1,19 +1,4 @@
"""Defense-in-depth: image-gen call sites must not let an empty
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
The bug repro: an OpenRouter image-gen config ships
``api_base=""``. The pre-fix call site in
``image_generation_routes._execute_image_generation`` did
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
silently dropped the empty string. LiteLLM then fell back to
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
and OpenRouter's ``image_generation/transformation`` appended
``/chat/completions`` to it 404 ``Resource not found``.
This test pins the post-fix behaviour: with an empty ``api_base`` in
the config, the call site MUST set ``api_base`` to OpenRouter's public
URL instead of leaving it unset.
"""
"""Image-gen call sites must pass each config's explicit ``api_base``."""
from __future__ import annotations
@ -26,22 +11,23 @@ pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
"""The global-config branch (``config_id < 0``) of
``_execute_image_generation`` must apply the resolver and pin
``api_base`` to OpenRouter when the config ships an empty string.
"""
async def test_global_openrouter_image_gen_sets_explicit_api_base():
"""The global-config branch forwards the explicit OpenRouter base."""
from app.routes import image_generation_routes
cfg = {
global_model = {
"id": -20_001,
"name": "GPT Image 1 (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-image-1",
"connection_id": -101,
"model_id": "openai/gpt-image-1",
"supports_image_generation": True,
"capabilities_override": {},
}
global_connection = {
"id": -101,
"provider": "openrouter",
"api_key": "sk-or-test",
"api_base": "", # the original bug shape
"api_version": None,
"litellm_params": {},
"base_url": "https://openrouter.ai/api/v1",
"extra": {},
}
captured: dict = {}
@ -51,7 +37,7 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})
image_gen = MagicMock()
image_gen.image_generation_config_id = cfg["id"]
image_gen.image_gen_model_id = global_model["id"]
image_gen.prompt = "test"
image_gen.n = 1
image_gen.quality = None
@ -61,14 +47,19 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
image_gen.model = None
search_space = MagicMock()
search_space.image_generation_config_id = cfg["id"]
search_space.image_gen_model_id = global_model["id"]
session = MagicMock()
with (
patch.object(
image_generation_routes,
"_get_global_image_gen_config",
return_value=cfg,
"_get_global_model",
return_value=global_model,
),
patch.object(
image_generation_routes,
"_get_global_connection",
return_value=global_connection,
),
patch.object(
image_generation_routes,
@ -80,30 +71,31 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
session=session, image_gen=image_gen, search_space=search_space
)
# The whole point of the fix: even with empty ``api_base`` in the
# config, we forward OpenRouter's public URL so the call doesn't
# inherit an Azure endpoint.
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
assert captured["model"] == "openrouter/openai/gpt-image-1"
@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
"""Same defense at the agent tool entry point — both surfaces share
async def test_generate_image_tool_global_sets_explicit_api_base():
"""Same explicit-base behavior at the agent tool entry point — both surfaces share
the same OpenRouter config payloads."""
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import (
generate_image as gi_module,
)
cfg = {
global_model = {
"id": -20_001,
"name": "GPT Image 1 (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-image-1",
"connection_id": -101,
"model_id": "openai/gpt-image-1",
"supports_image_generation": True,
"capabilities_override": {},
}
global_connection = {
"id": -101,
"provider": "openrouter",
"api_key": "sk-or-test",
"api_base": "",
"api_version": None,
"litellm_params": {},
"base_url": "https://openrouter.ai/api/v1",
"extra": {},
}
captured: dict = {}
@ -119,7 +111,7 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
search_space = MagicMock()
search_space.id = 1
search_space.image_generation_config_id = cfg["id"]
search_space.image_gen_model_id = global_model["id"]
session_cm = AsyncMock()
session = AsyncMock()
@ -142,7 +134,10 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
with (
patch.object(gi_module, "shielded_async_session", return_value=session_cm),
patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
patch.object(gi_module, "_get_global_model", return_value=global_model),
patch.object(
gi_module, "_get_global_connection", return_value=global_connection
),
patch.object(
gi_module, "aimage_generation", side_effect=fake_aimage_generation
),
@ -171,20 +166,16 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
assert captured["model"] == "openrouter/openai/gpt-image-1"
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
"""The Auto-mode router pool must also resolve ``api_base`` when an
OpenRouter config ships an empty string. The deployment dict is fed
straight to ``litellm.Router``, so a missing ``api_base`` would
leak the same way as the direct call sites.
"""
def test_image_gen_router_deployment_sets_explicit_api_base():
"""The Auto-mode router pool carries explicit api_base into deployments."""
from app.services.image_gen_router_service import ImageGenRouterService
deployment = ImageGenRouterService._config_to_deployment(
{
"model_name": "openai/gpt-image-1",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"api_key": "sk-or-test",
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
}
)
assert deployment is not None

View file

@ -25,10 +25,10 @@ def _fake_yaml_config(
return {
"id": id,
"name": f"yaml-{id}",
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": model_name,
"api_key": "sk-test",
"api_base": "",
"api_base": "https://api.openai.com/v1",
"billing_tier": billing_tier,
"rpm": 100,
"tpm": 100_000,
@ -54,10 +54,10 @@ def _fake_openrouter_config(
return {
"id": id,
"name": f"or-{id}",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": model_name,
"api_key": "sk-or-test",
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": billing_tier,
"rpm": 20 if billing_tier == "free" else 200,
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
@ -217,10 +217,64 @@ def test_auto_model_pin_candidates_include_dynamic_openrouter():
model_name="meta-llama/llama-3.3-70b:free",
billing_tier="free",
)
original = config.GLOBAL_LLM_CONFIGS
global_connections = [
{
"id": -110_001,
"provider": "openrouter",
"scope": "GLOBAL",
"enabled": True,
},
{
"id": -110_002,
"provider": "openrouter",
"scope": "GLOBAL",
"enabled": True,
},
]
global_models = [
{
"id": or_premium["id"],
"connection_id": -110_001,
"model_id": or_premium["model_name"],
"display_name": or_premium["name"],
"supports_chat": True,
"supports_image_input": True,
"supports_tools": True,
"supports_image_generation": False,
"capabilities_override": {},
"billing_tier": or_premium["billing_tier"],
"catalog": {
"auto_pin_tier": "A",
"quality_score": 50,
},
},
{
"id": or_free["id"],
"connection_id": -110_002,
"model_id": or_free["model_name"],
"display_name": or_free["name"],
"supports_chat": True,
"supports_image_input": True,
"supports_tools": True,
"supports_image_generation": False,
"capabilities_override": {},
"billing_tier": or_free["billing_tier"],
"catalog": {
"auto_pin_tier": "A",
"quality_score": 50,
},
},
]
original_configs = config.GLOBAL_LLM_CONFIGS
original_connections = config.GLOBAL_CONNECTIONS
original_models = config.GLOBAL_MODELS
try:
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
config.GLOBAL_CONNECTIONS = global_connections
config.GLOBAL_MODELS = global_models
candidate_ids = {c["id"] for c in _global_candidates()}
assert candidate_ids == {-10_001, -10_002}
finally:
config.GLOBAL_LLM_CONFIGS = original
config.GLOBAL_LLM_CONFIGS = original_configs
config.GLOBAL_CONNECTIONS = original_connections
config.GLOBAL_MODELS = original_models

View file

@ -0,0 +1,78 @@
from app.services.global_model_catalog import materialize_global_model_catalog
from app.services.model_resolver import ensure_v1, to_litellm
def test_openai_compatible_resolver_uses_explicit_api_base() -> None:
model, kwargs = to_litellm(
{
"protocol": "OPENAI_COMPATIBLE",
"provider": "openai",
"base_url": "http://host.docker.internal:1234/v1",
"api_key": "local-key",
"extra": {},
},
"qwen/qwen3",
)
assert model == "openai/qwen/qwen3"
assert kwargs["api_base"] == "http://host.docker.internal:1234/v1"
assert kwargs["api_key"] == "local-key"
assert ensure_v1("http://example.com/v1") == "http://example.com/v1"
def test_ollama_resolver_uses_native_api_base() -> None:
model, kwargs = to_litellm(
{
"protocol": "OLLAMA",
"provider": "ollama_chat",
"base_url": "http://host.docker.internal:11434",
"api_key": None,
"extra": {},
},
"llama3.2",
)
assert model == "ollama_chat/llama3.2"
assert kwargs["api_base"] == "http://host.docker.internal:11434"
def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None:
connections, models = materialize_global_model_catalog(
chat_configs=[
{
"id": -101,
"name": "OpenRouter Free",
"litellm_provider": "openrouter",
"model_name": "meta-llama/llama-3.1-8b-instruct:free",
"api_key": "sk-global-secret",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "free",
"anonymous_enabled": True,
"seo_enabled": True,
"rpm": 10,
"tpm": 1000,
},
{
"id": -102,
"name": "OpenRouter Premium",
"litellm_provider": "openrouter",
"model_name": "anthropic/claude-sonnet-4",
"api_key": "sk-global-secret",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "premium",
},
],
image_configs=[],
)
assert len(connections) == 1
assert connections[0]["api_key"] == "sk-global-secret"
assert {model["billing_tier"] for model in models} == {"free", "premium"}
assert models[0]["catalog"]["anonymous_enabled"] is True
assert models[0]["catalog"]["rpm"] == 10
public_connections = [
{key: value for key, value in connection.items() if key != "api_key"}
for connection in connections
]
assert "sk-" not in repr(public_connections)

View file

@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
# ---------------------------------------------------------------------------
# _generate_image_gen_configs / _generate_vision_llm_configs
# _generate_image_gen_configs
# ---------------------------------------------------------------------------
@ -263,18 +263,15 @@ def test_generate_image_gen_configs_filters_by_image_output():
# Each config must carry ``billing_tier`` for routing in image_generation_routes.
for c in cfgs:
assert c["billing_tier"] in {"free", "premium"}
assert c["provider"] == "OPENROUTER"
assert c["provider"] == "openrouter"
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
# Defense-in-depth: emit the OpenRouter base URL at source so a
# downstream call site that forgets ``resolve_api_base`` still
# doesn't 404 against an inherited Azure endpoint.
# Emit the OpenRouter base URL at source so every call path passes an
# explicit api_base and cannot inherit a process-global endpoint.
assert c["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_image_gen_configs_assigns_image_id_offset():
"""Image configs use a different id_offset (-20000) so their negative
IDs don't collide with chat configs (-10000) or vision configs (-30000).
"""
"""Image configs use their own id_offset (-20000)."""
from app.services.openrouter_integration_service import (
_generate_image_gen_configs,
)
@ -291,90 +288,3 @@ def test_generate_image_gen_configs_assigns_image_id_offset():
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
assert all(c["id"] < -20_000 + 1 for c in cfgs)
assert all(c["id"] > -29_000_000 for c in cfgs)
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
"""Vision LLMs must accept image input AND emit text — pure image-gen
(no text out) and text-only (no image in) models are excluded.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
# GPT-4o: vision LLM (image in, text out) — must emit.
{
"id": "openai/gpt-4o",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"context_length": 128_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
# Pure image generator — image *output*, no text out. Must NOT emit.
{
"id": "openai/gpt-image-1",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["image"],
},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
},
# Pure text model (no image in). Must NOT emit.
{
"id": "anthropic/claude-3-haiku",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["text"],
},
"context_length": 200_000,
"pricing": {"prompt": "0.000001", "completion": "0.000005"},
},
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
names = {c["model_name"] for c in cfgs}
assert names == {"openai/gpt-4o"}
cfg = cfgs[0]
assert cfg["billing_tier"] == "premium"
# Pricing carried inline so pricing_registration can register vision
# under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
# is cleared.
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
# Defense-in-depth: emit the OpenRouter base URL at source so a
# downstream call site that forgets ``resolve_api_base`` still
# doesn't inherit an Azure endpoint.
assert cfg["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_vision_llm_configs_drops_chat_only_filters():
"""A small-context vision model that doesn't advertise tool calling is
still a valid vision LLM for "describe this image" prompts. The chat
filters (``supports_tool_calling``, ``has_sufficient_context``) must
NOT be applied to vision emission.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
{
"id": "tiny/vision-mini",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"supported_parameters": [], # no tools
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
}
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
assert len(cfgs) == 1
assert cfgs[0]["model_name"] == "tiny/vision-mini"

View file

@ -25,7 +25,7 @@ def _or_cfg(
) -> dict:
return {
"id": cid,
"provider": "OPENROUTER",
"provider": "openrouter",
"model_name": model_name,
"billing_tier": tier,
"auto_pin_tier": "B" if tier == "premium" else "C",
@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch):
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
yaml_cfg = {
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"billing_tier": "premium",
"auto_pin_tier": "A",
@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
yaml_cfg: dict[str, Any] = {
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"billing_tier": "premium",
}

View file

@ -186,7 +186,7 @@ def test_openrouter_models_register_under_aliases(monkeypatch):
[
{
"id": 1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
@ -228,7 +228,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch):
[
{
"id": 1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5.4",
"litellm_params": {
"base_model": "gpt-5.4",
@ -243,7 +243,6 @@ def test_yaml_override_registers_under_alias_set(monkeypatch):
keys = spy.all_keys
assert "gpt-5.4" in keys
assert "azure_openai/gpt-5.4" in keys
assert "azure/gpt-5.4" in keys
payload = spy.calls[0]
@ -271,7 +270,7 @@ def test_no_override_means_no_registration(monkeypatch):
[
{
"id": 1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-4o",
"litellm_params": {"base_model": "gpt-4o"},
}
@ -302,7 +301,7 @@ def test_openrouter_skipped_when_pricing_missing(monkeypatch):
[
{
"id": 1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
@ -349,12 +348,12 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog):
[
{
"id": 1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "anthropic/claude-3-5-sonnet",
},
{
"id": 2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "custom-deployment",
"litellm_params": {
"base_model": "custom-deployment",
@ -369,79 +368,3 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog):
# The good config still registered.
assert any("custom-deployment" in payload for payload in successful_calls)
def test_vision_configs_registered_with_chat_shape(monkeypatch):
"""``register_pricing_from_global_configs`` walks
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
calls (during indexing) bill correctly. Vision configs use the same
chat-shape token prices, but image-gen pricing is intentionally NOT
registered here (handled via ``response_cost`` in LiteLLM).
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch,
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
)
# No chat configs — only vision. Proves the vision walk is a separate
# iteration, not piggy-backed on the chat list.
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"billing_tier": "premium",
"input_cost_per_token": 5e-6,
"output_cost_per_token": 15e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/openai/gpt-4o" in spy.all_keys
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
assert payload_value["mode"] == "chat"
assert payload_value["litellm_provider"] == "openrouter"
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
"""If the OpenRouter pricing cache misses a vision model (different
catalogue surface), the vision walk falls back to inline
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash",
"billing_tier": "premium",
"input_cost_per_token": 1e-6,
"output_cost_per_token": 4e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys

View file

@ -1,107 +0,0 @@
"""Unit tests for the shared ``api_base`` resolver.
The cascade exists so vision and image-gen call sites can't silently
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
when an OpenRouter / Groq / etc. config ships an empty string. See
``provider_api_base`` module docstring for the original repro
(OpenRouter image-gen 404-ing against an Azure endpoint).
"""
from __future__ import annotations
import pytest
from app.services.provider_api_base import (
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)
pytestmark = pytest.mark.unit
def test_config_value_wins_over_defaults():
"""A non-empty config value is always returned verbatim, even when the
provider has a default the operator gets the last word."""
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base="https://my-openrouter-mirror.example.com/v1",
)
assert result == "https://my-openrouter-mirror.example.com/v1"
def test_provider_key_default_when_config_missing():
"""``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
base URL the provider-key map must take precedence over the prefix
map so DeepSeek requests don't go to OpenAI."""
result = resolve_api_base(
provider="DEEPSEEK",
provider_prefix="openai",
config_api_base=None,
)
assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
def test_provider_prefix_default_when_no_key_default():
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base=None,
)
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_unknown_provider_returns_none():
"""When neither map matches we return ``None`` so the caller can let
LiteLLM apply its own provider-integration default (Azure deployment
URL, custom-provider URL, etc.)."""
result = resolve_api_base(
provider="SOMETHING_NEW",
provider_prefix="something_new",
config_api_base=None,
)
assert result is None
def test_empty_string_config_treated_as_missing():
"""The original bug: OpenRouter dynamic configs ship ``api_base=""``
and downstream call sites use ``if cfg.get("api_base"):`` empty
strings are falsy in Python but the cascade has to step in anyway."""
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base="",
)
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_whitespace_only_config_treated_as_missing():
"""A config value of ``" "`` is a configuration mistake — treat it
as missing instead of forwarding whitespace to LiteLLM (which would
almost certainly 404)."""
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base=" ",
)
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_provider_case_insensitive():
"""Some call sites pass the provider lowercase (DB enum value), others
uppercase (YAML key). Both must resolve."""
upper = resolve_api_base(
provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
)
lower = resolve_api_base(
provider="deepseek", provider_prefix="openai", config_api_base=None
)
assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
def test_all_inputs_none_returns_none():
assert (
resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
is None
)

View file

@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit
def test_or_modalities_with_image_returns_true():
assert (
derive_supports_image_input(
provider="OPENROUTER",
provider="openrouter",
model_name="openai/gpt-4o",
openrouter_input_modalities=["text", "image"],
)
@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true():
def test_or_modalities_text_only_returns_false():
assert (
derive_supports_image_input(
provider="OPENROUTER",
provider="openrouter",
model_name="deepseek/deepseek-v3.2-exp",
openrouter_input_modalities=["text"],
)
@ -57,7 +57,7 @@ def test_or_modalities_empty_list_returns_false():
to LiteLLM."""
assert (
derive_supports_image_input(
provider="OPENROUTER",
provider="openrouter",
model_name="weird/empty-modalities",
openrouter_input_modalities=[],
)
@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm():
to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
assert (
derive_supports_image_input(
provider="OPENAI",
provider="openai",
model_name="gpt-4o",
openrouter_input_modalities=None,
)
@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm():
def test_litellm_known_vision_model_returns_true():
assert (
derive_supports_image_input(
provider="OPENAI",
provider="openai",
model_name="gpt-4o",
)
is True
@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name():
doesn't know) would shadow the real capability."""
assert (
derive_supports_image_input(
provider="AZURE_OPENAI",
provider="azure",
model_name="my-azure-deployment-id",
base_model="gpt-4o",
)
@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows():
"""Default-allow on unknown — the safety net is the actual block."""
assert (
derive_supports_image_input(
provider="CUSTOM",
provider="custom",
model_name="brand-new-model-x9-unmapped",
custom_provider="brand_new_proxy",
)
@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false():
# Sanity: confirm the helper's negative path. We use a small model
# known not to support vision per the map.
result = derive_supports_image_input(
provider="DEEPSEEK",
provider="openai",
model_name="deepseek-chat",
)
# We accept either False (LiteLLM said explicit no) or True
@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false():
def test_is_known_text_only_returns_false_for_vision_model():
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="gpt-4o",
)
is False
@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model():
fixing."""
assert (
is_known_text_only_chat_model(
provider="CUSTOM",
provider="custom",
model_name="brand-new-model-x9-unmapped",
custom_provider="brand_new_proxy",
)
@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="gpt-4o",
)
is False
@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="any-model",
)
is True
@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="any-model",
)
is False
@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="any-model",
)
is False

View file

@ -1,4 +1,4 @@
"""Unit tests for the Auto (Fastest) quality scoring module."""
"""Unit tests for the Auto quality scoring module."""
from __future__ import annotations
@ -228,7 +228,7 @@ def test_static_score_or_recent_release_beats_year_old_same_provider():
def test_static_score_yaml_includes_operator_bonus():
cfg = {
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"litellm_params": {"base_model": "azure/gpt-5"},
}
@ -238,7 +238,7 @@ def test_static_score_yaml_includes_operator_bonus():
def test_static_score_yaml_unknown_provider_still_carries_bonus():
cfg = {
"provider": "SOME_NEW_PROVIDER",
"litellm_provider": "some_new_provider",
"model_name": "weird-model",
}
score = static_score_yaml(cfg)
@ -247,7 +247,7 @@ def test_static_score_yaml_unknown_provider_still_carries_bonus():
def test_static_score_yaml_clamped_0_to_100():
cfg = {
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"litellm_params": {"base_model": "azure/gpt-5"},
}

View file

@ -131,6 +131,10 @@ def test_serialized_calls_includes_cost_micros():
assert serialized == [
{
"model": "m",
"model_ref": None,
"model_id": None,
"display_name": None,
"provider": None,
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,

View file

@ -1,89 +0,0 @@
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
defaults from ``litellm.api_base`` either.
Vision shares the same shape as image-gen global YAML / OpenRouter
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
call sites would silently drop the empty string and inherit
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
construction so we test the kwargs we hand to it instead.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
"""Global negative-ID branch: an OpenRouter vision config with
``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
``api_base="https://openrouter.ai/api/v1"`` never an empty string,
never silently absent."""
from app.services import llm_service
cfg = {
"id": -30_001,
"name": "GPT-4o Vision (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"api_key": "sk-or-test",
"api_base": "",
"api_version": None,
"litellm_params": {},
"billing_tier": "free",
}
search_space = MagicMock()
search_space.id = 1
search_space.user_id = "user-x"
search_space.vision_llm_config_id = cfg["id"]
session = AsyncMock()
scalars = MagicMock()
scalars.first.return_value = search_space
result = MagicMock()
result.scalars.return_value = scalars
session.execute.return_value = result
captured: dict = {}
class FakeSanitized:
def __init__(self, **kwargs):
captured.update(kwargs)
with (
patch(
"app.services.vision_llm_router_service.get_global_vision_llm_config",
return_value=cfg,
),
patch(
"app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM",
new=FakeSanitized,
),
):
await llm_service.get_vision_llm(session=session, search_space_id=1)
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
assert captured["model"] == "openrouter/openai/gpt-4o"
def test_vision_router_deployment_sets_api_base_when_config_empty():
"""Auto-mode vision router: deployments are fed to ``litellm.Router``,
so the resolver has to apply at deployment construction time too."""
from app.services.vision_llm_router_service import VisionLLMRouterService
deployment = VisionLLMRouterService._config_to_deployment(
{
"model_name": "openai/gpt-4o",
"provider": "OPENROUTER",
"api_key": "sk-or-test",
"api_base": "",
}
)
assert deployment is not None
assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"

View file

@ -0,0 +1,79 @@
from __future__ import annotations
import pytest
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
pytestmark = pytest.mark.unit
def _exception_named(name: str, message: str) -> Exception:
return type(name, (Exception,), {})(message)
def test_adapter_classifies_authentication_error_by_class_name() -> None:
exc = _exception_named("AuthenticationError", "provider rejected credentials")
adapted = adapt_llm_exception(exc)
assert adapted.category is LLMErrorCategory.AUTH_FAILED
assert adapted.retryable is False
assert adapted.user_message == "LLM authentication failed. Check your API key."
def test_adapter_classifies_embedded_provider_401_payload() -> None:
exc = RuntimeError(
'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}'
)
adapted = adapt_llm_exception(exc)
assert adapted.category is LLMErrorCategory.AUTH_FAILED
assert adapted.provider_status_code == 401
def test_adapter_preserves_rate_limit_classification() -> None:
exc = RuntimeError('{"error":{"message":"Slow down","code":429}}')
adapted = adapt_llm_exception(exc)
assert adapted.category is LLMErrorCategory.RATE_LIMITED
assert adapted.retryable is True
def test_stream_classifier_maps_model_auth_to_stable_code() -> None:
exc = RuntimeError(
'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}'
)
kind, code, severity, expected, message, extra = classify_stream_exception(
exc,
flow_label="chat",
)
assert kind == "model_auth_failed"
assert code == "MODEL_AUTH_FAILED"
assert severity == "warn"
assert expected is True
assert "API key" in message
assert extra == {
"provider_error_category": "auth_failed",
"provider_status_code": 401,
}
def test_stream_classifier_keeps_unknown_errors_generic() -> None:
exc = RuntimeError("database exploded")
kind, code, severity, expected, message, extra = classify_stream_exception(
exc,
flow_label="chat",
)
assert kind == "server_error"
assert code == "SERVER_ERROR"
assert severity == "error"
assert expected is False
assert message == "Error during chat: database exploded"
assert extra is None

View file

@ -0,0 +1,61 @@
"""Unit tests for provider-safe LLM history normalization."""
from __future__ import annotations
import pytest
from app.tasks.chat.llm_history_normalizer import (
assistant_content_to_llm_text,
user_content_to_llm_content,
)
pytestmark = pytest.mark.unit
def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None:
content = [
{"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]},
{"type": "text", "text": "visible answer"},
]
assert assistant_content_to_llm_text(content) == "visible answer"
def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None:
content = [
{"type": "thinking", "thinking": "private reasoning"},
{"type": "text", "text": "final answer"},
]
assert assistant_content_to_llm_text(content) == "final answer"
def test_unknown_assistant_blocks_are_dropped() -> None:
content = [
{"type": "redacted_thinking", "data": "hidden"},
{"type": "tool_use", "name": "search"},
{"type": "text", "text": "kept"},
]
assert assistant_content_to_llm_text(content) == "kept"
def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None:
content = [
{"type": "text", "text": "look"},
{"type": "image", "image": "data:image/png;base64,abc"},
]
assert user_content_to_llm_content(content, allow_images=True) == [
{"type": "text", "text": "look"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]
def test_user_images_can_be_dropped_for_text_only_history() -> None:
content = [
{"type": "text", "text": "look"},
{"type": "image", "image": "data:image/png;base64,abc"},
]
assert user_content_to_llm_content(content, allow_images=False) == "look"

View file

@ -0,0 +1,67 @@
"""Unit tests for final assistant message part normalization."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.tasks.chat.message_parts_normalizer import (
final_assistant_parts_from_messages,
merge_streamed_and_final_parts,
normalize_ai_message_to_parts,
)
pytestmark = pytest.mark.unit
def test_string_ai_message_content_becomes_text_part() -> None:
assert normalize_ai_message_to_parts(AIMessage(content="hello")) == [
{"type": "text", "text": "hello"}
]
def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None:
message = AIMessage(
content=[
{"type": "thinking", "thinking": "hidden reasoning"},
{"type": "text", "text": "Yo bro! What's up?"},
],
additional_kwargs={"reasoning_content": "hidden reasoning"},
)
assert normalize_ai_message_to_parts(message) == [
{"type": "text", "text": "Yo bro! What's up?"}
]
def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None:
messages = [
HumanMessage(content="ask"),
AIMessage(content="draft"),
ToolMessage(content="tool output", tool_call_id="tc-1"),
AIMessage(content=[{"type": "text", "text": "final answer"}]),
ToolMessage(content="trailing tool noise", tool_call_id="tc-2"),
]
assert final_assistant_parts_from_messages(messages) == [
{"type": "text", "text": "final answer"}
]
def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None:
streamed = [
{
"type": "data-thinking-steps",
"data": [{"id": "thinking-1", "status": "completed"}],
}
]
final = [{"type": "text", "text": "visible answer"}]
assert merge_streamed_and_final_parts(streamed, final) == [*streamed, *final]
def test_merge_does_not_duplicate_when_stream_already_has_text() -> None:
streamed = [{"type": "text", "text": "streamed answer"}]
final = [{"type": "text", "text": "final answer"}]
assert merge_streamed_and_final_parts(streamed, final) == streamed

View file

@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o():
it text-only."""
assert (
is_known_text_only_chat_model(
provider="AZURE_OPENAI",
provider="azure",
model_name="my-azure-deployment",
base_model="gpt-4o",
)
@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model():
LiteLLM doesn't know about must flow through to the provider."""
assert (
is_known_text_only_chat_model(
provider="CUSTOM",
provider="custom",
custom_provider="brand_new_proxy",
model_name="brand-new-model-x9",
)
@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="gpt-4o",
)
is False
@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="text-only-stub",
)
is True
@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="vision-stub",
)
is False
@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
provider="openai",
model_name="missing-key-stub",
)
is False

Some files were not shown because too many files have changed in this diff Show more