mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-16 21:05:20 +02:00
Merge remote-tracking branch 'upstream/dev' into features/documents-injestion-layered-cached
This commit is contained in:
commit
32a6e54ce6
215 changed files with 9532 additions and 15405 deletions
|
|
@ -4,7 +4,7 @@ Revision ID: 138
|
|||
Revises: 137
|
||||
Create Date: 2026-04-30
|
||||
|
||||
Add a single thread-level column to persist the Auto (Fastest) model pin:
|
||||
Add a single thread-level column to persist the Auto model pin:
|
||||
- pinned_llm_config_id: concrete resolved global LLM config id used for this
|
||||
thread. NULL means "no pin; Auto will resolve on next turn".
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,19 @@ down_revision: str | None = "157"
|
|||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
PUBLICATION_NAME = "zero_publication"
|
||||
TARGET_STATUS_LABELS = (
|
||||
"pending",
|
||||
"awaiting_brief",
|
||||
"drafting",
|
||||
"awaiting_review",
|
||||
"rendering",
|
||||
"ready",
|
||||
"failed",
|
||||
"cancelled",
|
||||
)
|
||||
LEGACY_STATUS_LABELS = ("pending", "generating", "ready", "failed")
|
||||
|
||||
|
||||
def _drop_podcasts_from_publication() -> None:
|
||||
"""Detach podcasts from zero_publication so status can be retyped.
|
||||
|
|
@ -28,31 +41,103 @@ def _drop_podcasts_from_publication() -> None:
|
|||
published = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_publication_tables "
|
||||
"WHERE pubname = 'zero_publication' "
|
||||
"WHERE pubname = :publication "
|
||||
"AND schemaname = current_schema() AND tablename = 'podcasts'"
|
||||
)
|
||||
),
|
||||
{"publication": PUBLICATION_NAME},
|
||||
).fetchone()
|
||||
if published:
|
||||
op.execute('ALTER PUBLICATION "zero_publication" DROP TABLE "podcasts";')
|
||||
op.execute(f'ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts";')
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
def _enum_labels(type_name: str) -> list[str] | None:
|
||||
rows = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.text(
|
||||
"SELECT e.enumlabel "
|
||||
"FROM pg_type t "
|
||||
"JOIN pg_namespace n ON n.oid = t.typnamespace "
|
||||
"JOIN pg_enum e ON e.enumtypid = t.oid "
|
||||
"WHERE n.nspname = current_schema() AND t.typname = :type_name "
|
||||
"ORDER BY e.enumsortorder"
|
||||
),
|
||||
{"type_name": type_name},
|
||||
)
|
||||
.fetchall()
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
return [str(row[0]) for row in rows]
|
||||
|
||||
# Retype the status enum by swapping in a fresh type and casting existing
|
||||
# rows. The legacy transient value 'generating' maps onto 'rendering'.
|
||||
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;")
|
||||
|
||||
def _column_type_name(table: str, column: str) -> str | None:
|
||||
row = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.text(
|
||||
"SELECT udt_name "
|
||||
"FROM information_schema.columns "
|
||||
"WHERE table_schema = current_schema() "
|
||||
"AND table_name = :table AND column_name = :column"
|
||||
),
|
||||
{"table": table, "column": column},
|
||||
)
|
||||
.fetchone()
|
||||
)
|
||||
return str(row[0]) if row else None
|
||||
|
||||
|
||||
def _ensure_status_enum(
|
||||
*,
|
||||
desired_labels: tuple[str, ...],
|
||||
temporary_type: str,
|
||||
create_sql: str,
|
||||
alter_sql: str,
|
||||
default_value: str,
|
||||
) -> None:
|
||||
current_labels = _enum_labels("podcast_status")
|
||||
desired = list(desired_labels)
|
||||
|
||||
if current_labels != desired:
|
||||
if current_labels is None:
|
||||
if _enum_labels(temporary_type) is None:
|
||||
raise RuntimeError("podcast_status enum is missing")
|
||||
elif _enum_labels(temporary_type) is None:
|
||||
op.execute(f"ALTER TYPE podcast_status RENAME TO {temporary_type};")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"podcast_status and its temporary replacement both exist"
|
||||
)
|
||||
|
||||
if _enum_labels("podcast_status") is None:
|
||||
op.execute(create_sql)
|
||||
|
||||
if _enum_labels("podcast_status") != desired:
|
||||
raise RuntimeError("podcast_status enum is not in the expected shape")
|
||||
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
if _column_type_name("podcasts", "status") != "podcast_status":
|
||||
op.execute(alter_sql)
|
||||
op.execute(
|
||||
"""
|
||||
f"ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT '{default_value}';"
|
||||
)
|
||||
|
||||
if _enum_labels(temporary_type) is not None:
|
||||
op.execute(f"DROP TYPE {temporary_type};")
|
||||
|
||||
|
||||
def _upgrade_status_enum() -> None:
|
||||
_ensure_status_enum(
|
||||
desired_labels=TARGET_STATUS_LABELS,
|
||||
temporary_type="podcast_status_old",
|
||||
create_sql="""
|
||||
CREATE TYPE podcast_status AS ENUM (
|
||||
'pending', 'awaiting_brief', 'drafting', 'awaiting_review',
|
||||
'rendering', 'ready', 'failed', 'cancelled'
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
op.execute(
|
||||
"""
|
||||
""",
|
||||
alter_sql="""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
|
|
@ -61,10 +146,43 @@ def upgrade() -> None:
|
|||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
"""
|
||||
""",
|
||||
default_value="pending",
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'pending';")
|
||||
op.execute("DROP TYPE podcast_status_old;")
|
||||
|
||||
|
||||
def _downgrade_status_enum() -> None:
|
||||
_ensure_status_enum(
|
||||
desired_labels=LEGACY_STATUS_LABELS,
|
||||
temporary_type="podcast_status_new",
|
||||
create_sql=(
|
||||
"CREATE TYPE podcast_status AS ENUM "
|
||||
"('pending', 'generating', 'ready', 'failed');"
|
||||
),
|
||||
alter_sql="""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
CASE status::text
|
||||
WHEN 'awaiting_brief' THEN 'pending'
|
||||
WHEN 'drafting' THEN 'generating'
|
||||
WHEN 'awaiting_review' THEN 'generating'
|
||||
WHEN 'rendering' THEN 'generating'
|
||||
WHEN 'cancelled' THEN 'failed'
|
||||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
""",
|
||||
default_value="ready",
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
|
||||
# Retype the status enum by swapping in a fresh type and casting existing
|
||||
# rows. The legacy transient value 'generating' maps onto 'rendering'.
|
||||
_upgrade_status_enum()
|
||||
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;")
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;")
|
||||
|
|
@ -83,6 +201,8 @@ def upgrade() -> None:
|
|||
|
||||
|
||||
def downgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;")
|
||||
|
|
@ -92,27 +212,4 @@ def downgrade() -> None:
|
|||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;")
|
||||
|
||||
# Collapse the expanded lifecycle back onto the original four values.
|
||||
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_new;")
|
||||
op.execute(
|
||||
"CREATE TYPE podcast_status AS ENUM "
|
||||
"('pending', 'generating', 'ready', 'failed');"
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
CASE status::text
|
||||
WHEN 'awaiting_brief' THEN 'pending'
|
||||
WHEN 'drafting' THEN 'generating'
|
||||
WHEN 'awaiting_review' THEN 'generating'
|
||||
WHEN 'rendering' THEN 'generating'
|
||||
WHEN 'cancelled' THEN 'failed'
|
||||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'ready';")
|
||||
op.execute("DROP TYPE podcast_status_new;")
|
||||
_downgrade_status_enum()
|
||||
|
|
|
|||
299
surfsense_backend/alembic/versions/160_add_model_connections.py
Normal file
299
surfsense_backend/alembic/versions/160_add_model_connections.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
"""add model connections
|
||||
|
||||
Revision ID: 160
|
||||
Revises: 159
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "160"
|
||||
down_revision: str | None = "159"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
connection_scope = postgresql.ENUM(
|
||||
"GLOBAL",
|
||||
"SEARCH_SPACE",
|
||||
"USER",
|
||||
name="connectionscope",
|
||||
create_type=False,
|
||||
)
|
||||
model_source = postgresql.ENUM(
|
||||
"DISCOVERED",
|
||||
"MANUAL",
|
||||
name="modelsource",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
|
||||
return table_name in sa.inspect(op.get_bind()).get_table_names()
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return column_name in {
|
||||
column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name)
|
||||
}
|
||||
|
||||
|
||||
def _index_exists(table_name: str, index_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return index_name in {
|
||||
index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name)
|
||||
}
|
||||
|
||||
|
||||
def _create_index_if_missing(
|
||||
index_name: str,
|
||||
table_name: str,
|
||||
columns: list[str],
|
||||
) -> None:
|
||||
if not _index_exists(table_name, index_name):
|
||||
op.create_index(index_name, table_name, columns, unique=False)
|
||||
|
||||
|
||||
def _add_searchspace_column_if_missing(
|
||||
column_name: str,
|
||||
*,
|
||||
server_default: object | None = None,
|
||||
) -> None:
|
||||
if not _column_exists("searchspaces", column_name):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column(
|
||||
column_name,
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
server_default=server_default,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
||||
if _column_exists(table_name, column_name):
|
||||
op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _drop_index_if_exists(table_name: str, index_name: str) -> None:
|
||||
if _index_exists(table_name, index_name):
|
||||
op.drop_index(index_name, table_name=table_name)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
connection_scope.create(bind, checkfirst=True)
|
||||
model_source.create(bind, checkfirst=True)
|
||||
|
||||
if _table_exists("connections"):
|
||||
if _column_exists("connections", "litellm_provider") and not _column_exists(
|
||||
"connections", "provider"
|
||||
):
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"litellm_provider",
|
||||
new_column_name="provider",
|
||||
existing_type=sa.String(length=100),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"provider",
|
||||
existing_type=sa.String(length=100),
|
||||
nullable=False,
|
||||
)
|
||||
elif _column_exists("connections", "native_provider") and not _column_exists(
|
||||
"connections", "provider"
|
||||
):
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"native_provider",
|
||||
new_column_name="provider",
|
||||
existing_type=sa.String(length=100),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"provider",
|
||||
existing_type=sa.String(length=100),
|
||||
nullable=False,
|
||||
)
|
||||
elif not _column_exists("connections", "provider"):
|
||||
op.add_column(
|
||||
"connections",
|
||||
sa.Column("provider", sa.String(length=100), nullable=False),
|
||||
)
|
||||
_drop_index_if_exists("connections", "ix_connections_protocol")
|
||||
_drop_column_if_exists("connections", "protocol")
|
||||
else:
|
||||
op.create_table(
|
||||
"connections",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("provider", sa.String(length=100), nullable=False),
|
||||
sa.Column("base_url", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"extra",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("scope", connection_scope, nullable=False),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.CheckConstraint(
|
||||
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
|
||||
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
|
||||
"(scope = 'USER' AND user_id IS NOT NULL)",
|
||||
name="ck_connections_scope_owner",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
if _index_exists(
|
||||
"connections", "ix_connections_native_provider"
|
||||
) and not _index_exists("connections", "ix_connections_provider"):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_native_provider "
|
||||
"RENAME TO ix_connections_provider"
|
||||
)
|
||||
if _index_exists(
|
||||
"connections", "ix_connections_litellm_provider"
|
||||
) and not _index_exists("connections", "ix_connections_provider"):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_litellm_provider "
|
||||
"RENAME TO ix_connections_provider"
|
||||
)
|
||||
_create_index_if_missing("ix_connections_provider", "connections", ["provider"])
|
||||
_create_index_if_missing("ix_connections_scope", "connections", ["scope"])
|
||||
|
||||
if not _table_exists("models"):
|
||||
op.create_table(
|
||||
"models",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("connection_id", sa.Integer(), nullable=False),
|
||||
sa.Column("model_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("display_name", sa.String(length=255), nullable=True),
|
||||
sa.Column(
|
||||
"source",
|
||||
model_source,
|
||||
server_default="DISCOVERED",
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("supports_chat", sa.Boolean(), nullable=True),
|
||||
sa.Column("max_input_tokens", sa.Integer(), nullable=True),
|
||||
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
|
||||
sa.Column("supports_tools", sa.Boolean(), nullable=True),
|
||||
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
|
||||
sa.Column(
|
||||
"capabilities_override",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
sa.Column("billing_tier", sa.String(length=50), nullable=True),
|
||||
sa.Column(
|
||||
"catalog",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connection_id"], ["connections.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
),
|
||||
)
|
||||
else:
|
||||
if not _column_exists("models", "supports_chat"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_chat", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "max_input_tokens"):
|
||||
op.add_column(
|
||||
"models", sa.Column("max_input_tokens", sa.Integer(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_image_input"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_image_input", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_tools"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_tools", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_image_generation"):
|
||||
op.add_column(
|
||||
"models",
|
||||
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
|
||||
)
|
||||
_drop_column_if_exists("models", "capabilities")
|
||||
_drop_column_if_exists("models", "capabilities_declared")
|
||||
_drop_column_if_exists("models", "capabilities_verified")
|
||||
_create_index_if_missing("ix_models_connection_id", "models", ["connection_id"])
|
||||
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
|
||||
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
|
||||
|
||||
_add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0"))
|
||||
_add_searchspace_column_if_missing(
|
||||
"image_gen_model_id", server_default=sa.text("0")
|
||||
)
|
||||
_add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0"))
|
||||
for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"):
|
||||
op.alter_column(
|
||||
"searchspaces",
|
||||
column_name,
|
||||
existing_type=sa.Integer(),
|
||||
existing_nullable=True,
|
||||
server_default=sa.text("0"),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE searchspaces
|
||||
SET
|
||||
chat_model_id = COALESCE(chat_model_id, 0),
|
||||
image_gen_model_id = COALESCE(image_gen_model_id, 0),
|
||||
vision_model_id = COALESCE(vision_model_id, 0)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS connectionprotocol")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "vision_model_id")
|
||||
op.drop_column("searchspaces", "image_gen_model_id")
|
||||
op.drop_column("searchspaces", "chat_model_id")
|
||||
|
||||
op.drop_index(op.f("ix_models_billing_tier"), table_name="models")
|
||||
op.drop_index("ix_models_model_id", table_name="models")
|
||||
op.drop_index(op.f("ix_models_connection_id"), table_name="models")
|
||||
op.drop_table("models")
|
||||
|
||||
op.drop_index(op.f("ix_connections_scope"), table_name="connections")
|
||||
op.drop_index(op.f("ix_connections_provider"), table_name="connections")
|
||||
op.drop_table("connections")
|
||||
|
||||
bind = op.get_bind()
|
||||
model_source.drop(bind, checkfirst=True)
|
||||
connection_scope.drop(bind, checkfirst=True)
|
||||
|
|
@ -0,0 +1,270 @@
|
|||
"""remove legacy model config tables
|
||||
|
||||
Revision ID: 161
|
||||
Revises: 160
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "161"
|
||||
down_revision: str | None = "160"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
litellm_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"ANTHROPIC",
|
||||
"GOOGLE",
|
||||
"AZURE_OPENAI",
|
||||
"BEDROCK",
|
||||
"VERTEX_AI",
|
||||
"GROQ",
|
||||
"COHERE",
|
||||
"MISTRAL",
|
||||
"DEEPSEEK",
|
||||
"XAI",
|
||||
"OPENROUTER",
|
||||
"TOGETHER_AI",
|
||||
"FIREWORKS_AI",
|
||||
"REPLICATE",
|
||||
"PERPLEXITY",
|
||||
"OLLAMA",
|
||||
"ALIBABA_QWEN",
|
||||
"MOONSHOT",
|
||||
"ZHIPU",
|
||||
"ANYSCALE",
|
||||
"DEEPINFRA",
|
||||
"CEREBRAS",
|
||||
"SAMBANOVA",
|
||||
"AI21",
|
||||
"CLOUDFLARE",
|
||||
"DATABRICKS",
|
||||
"COMETAPI",
|
||||
"HUGGINGFACE",
|
||||
"GITHUB_MODELS",
|
||||
"MINIMAX",
|
||||
"CUSTOM",
|
||||
name="litellmprovider",
|
||||
create_type=False,
|
||||
)
|
||||
image_gen_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"AZURE_OPENAI",
|
||||
"GOOGLE",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"RECRAFT",
|
||||
"OPENROUTER",
|
||||
"XINFERENCE",
|
||||
"NSCALE",
|
||||
name="imagegenprovider",
|
||||
create_type=False,
|
||||
)
|
||||
vision_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"ANTHROPIC",
|
||||
"GOOGLE",
|
||||
"AZURE_OPENAI",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"XAI",
|
||||
"OPENROUTER",
|
||||
"OLLAMA",
|
||||
"GROQ",
|
||||
"TOGETHER_AI",
|
||||
"FIREWORKS_AI",
|
||||
"DEEPSEEK",
|
||||
"MISTRAL",
|
||||
"CUSTOM",
|
||||
name="visionprovider",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
|
||||
return table_name in sa.inspect(op.get_bind()).get_table_names()
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return column_name in {
|
||||
column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name)
|
||||
}
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
||||
if _column_exists(table_name, column_name):
|
||||
op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _rename_column_if_exists(
|
||||
table_name: str,
|
||||
old_column_name: str,
|
||||
new_column_name: str,
|
||||
*,
|
||||
existing_type: TypeEngine,
|
||||
existing_nullable: bool = True,
|
||||
) -> None:
|
||||
if _column_exists(table_name, old_column_name) and not _column_exists(
|
||||
table_name, new_column_name
|
||||
):
|
||||
op.alter_column(
|
||||
table_name,
|
||||
old_column_name,
|
||||
new_column_name=new_column_name,
|
||||
existing_type=existing_type,
|
||||
existing_nullable=existing_nullable,
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table_name in (
|
||||
"new_llm_configs",
|
||||
"vision_llm_configs",
|
||||
"image_generation_configs",
|
||||
):
|
||||
if _table_exists(table_name):
|
||||
op.drop_table(table_name)
|
||||
|
||||
_drop_column_if_exists("searchspaces", "agent_llm_id")
|
||||
_drop_column_if_exists("searchspaces", "image_generation_config_id")
|
||||
_drop_column_if_exists("searchspaces", "vision_llm_config_id")
|
||||
|
||||
_rename_column_if_exists(
|
||||
"image_generations",
|
||||
"image_generation_config_id",
|
||||
"image_gen_model_id",
|
||||
existing_type=sa.Integer(),
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS litellmprovider")
|
||||
op.execute("DROP TYPE IF EXISTS imagegenprovider")
|
||||
op.execute("DROP TYPE IF EXISTS visionprovider")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
litellm_provider.create(bind, checkfirst=True)
|
||||
image_gen_provider.create(bind, checkfirst=True)
|
||||
vision_provider.create(bind, checkfirst=True)
|
||||
|
||||
_rename_column_if_exists(
|
||||
"image_generations",
|
||||
"image_gen_model_id",
|
||||
"image_generation_config_id",
|
||||
existing_type=sa.Integer(),
|
||||
)
|
||||
|
||||
if _table_exists("searchspaces"):
|
||||
if not _column_exists("searchspaces", "agent_llm_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("agent_llm_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
if not _column_exists("searchspaces", "image_generation_config_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
if not _column_exists("searchspaces", "vision_llm_config_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("vision_llm_config_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
if not _table_exists("image_generation_configs"):
|
||||
op.create_table(
|
||||
"image_generation_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", image_gen_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_version", sa.String(length=50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_image_generation_configs_name"),
|
||||
"image_generation_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
if not _table_exists("vision_llm_configs"):
|
||||
op.create_table(
|
||||
"vision_llm_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", vision_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_version", sa.String(length=50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_vision_llm_configs_name"),
|
||||
"vision_llm_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
if not _table_exists("new_llm_configs"):
|
||||
op.create_table(
|
||||
"new_llm_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", litellm_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("system_instructions", sa.Text(), nullable=False),
|
||||
sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False),
|
||||
sa.Column("citations_enabled", sa.Boolean(), nullable=False),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_new_llm_configs_name"),
|
||||
"new_llm_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
"""add etl_cache_parses table for content-addressed parse reuse
|
||||
|
||||
Revision ID: 160
|
||||
Revises: 159
|
||||
Revision ID: 162
|
||||
Revises: 161
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "160"
|
||||
down_revision: str | None = "159"
|
||||
revision: str = "162"
|
||||
down_revision: str | None = "161"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
"""add embedding_cache_sets table for content-addressed embedding reuse
|
||||
|
||||
Revision ID: 161
|
||||
Revises: 160
|
||||
Revision ID: 163
|
||||
Revises: 162
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "161"
|
||||
down_revision: str | None = "160"
|
||||
revision: str = "163"
|
||||
down_revision: str | None = "162"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
|
@ -3,16 +3,16 @@
|
|||
Incremental re-indexing keeps unchanged chunk rows, so auto-increment ids no
|
||||
longer reflect document order. Backfill preserves the historical id ordering.
|
||||
|
||||
Revision ID: 162
|
||||
Revises: 161
|
||||
Revision ID: 164
|
||||
Revises: 163
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "162"
|
||||
down_revision: str | None = "161"
|
||||
revision: str = "164"
|
||||
down_revision: str | None = "163"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ async def build_agent_with_cache(
|
|||
mcp_tools_by_agent: dict[str, list[BaseTool]],
|
||||
disabled_tools: list[str] | None,
|
||||
config_id: str | None,
|
||||
image_generation_config_id_override: int | None = None,
|
||||
image_gen_model_id_override: int | None = None,
|
||||
) -> Any:
|
||||
"""Compile the multi-agent graph, serving from cache when key components are stable."""
|
||||
|
||||
|
|
@ -121,7 +121,7 @@ async def build_agent_with_cache(
|
|||
# Bound into the generate_image subagent tool at construction time, so it
|
||||
# must key the compiled-agent cache to avoid leaking one automation's
|
||||
# image model into another with the same config_id/search_space.
|
||||
image_generation_config_id_override,
|
||||
image_gen_model_id_override,
|
||||
)
|
||||
return await get_cache().get_or_build(cache_key, builder=_build)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent(
|
|||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
image_generation_config_id: int | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
|
||||
|
||||
``image_generation_config_id`` overrides the search space's image model for
|
||||
``image_gen_model_id`` overrides the search space's image model for
|
||||
this invocation (used by automations to run on their captured model). When
|
||||
``None``, the ``generate_image`` tool resolves the live search-space pref.
|
||||
"""
|
||||
|
|
@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
"llm": llm,
|
||||
# Per-invocation image model override (automations run on their captured
|
||||
# model). Reaches the generate_image subagent tool via subagent_dependencies.
|
||||
"image_generation_config_id_override": image_generation_config_id,
|
||||
"image_gen_model_id_override": image_gen_model_id,
|
||||
}
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
|
|
@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
config_id=config_id,
|
||||
image_generation_config_id_override=image_generation_config_id,
|
||||
image_gen_model_id_override=image_gen_model_id,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
|
|
|
|||
|
|
@ -10,70 +10,53 @@ from langgraph.types import Command
|
|||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Model,
|
||||
SearchSpace,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
def _get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image gen config by negative ID."""
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
def _get_global_connection(connection_id: int) -> dict | None:
|
||||
return next(
|
||||
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def create_generate_image_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
image_generation_config_id_override: int | None = None,
|
||||
image_gen_model_id_override: int | None = None,
|
||||
):
|
||||
"""Create ``generate_image`` with bound search space; DB work uses a per-call session.
|
||||
|
||||
``image_generation_config_id_override``: when set (automations running on a
|
||||
captured model), use this config id instead of reading the search space's
|
||||
live ``image_generation_config_id``.
|
||||
``image_gen_model_id_override``: when set (automations running on a
|
||||
captured model), use this model id instead of reading the search space's
|
||||
live ``image_gen_model_id``.
|
||||
"""
|
||||
del db_session # tool uses a fresh per-call session instead
|
||||
|
||||
|
|
@ -118,26 +101,23 @@ def create_generate_image_tool(
|
|||
# task's session is shared across every tool; without isolation,
|
||||
# autoflushes from a concurrent writer poison this tool too.
|
||||
async with shielded_async_session() as session:
|
||||
if image_generation_config_id_override is not None:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
if image_gen_model_id_override is not None:
|
||||
# Automation run: use the captured image model, insulated from
|
||||
# later search-space changes. No search-space read needed.
|
||||
config_id = (
|
||||
image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
config_id = image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
else:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id
|
||||
or IMAGE_GEN_AUTO_MODE_ID
|
||||
search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# size/quality/style are intentionally omitted: valid values
|
||||
|
|
@ -147,73 +127,82 @@ def create_generate_image_tool(
|
|||
gen_kwargs["n"] = n
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
if not candidates:
|
||||
err = (
|
||||
"No image generation models configured. "
|
||||
"No image generation models available. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
)
|
||||
return _failed({"error": err}, error=err)
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
)
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not has_capability(
|
||||
global_model, "image_gen"
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
global_connection = _get_global_connection(
|
||||
global_model["connection_id"]
|
||||
)
|
||||
if not global_connection:
|
||||
err = f"Image generation connection for model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
global_connection,
|
||||
global_model["model_id"],
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
# Positive ID = Model + Connection
|
||||
cfg_result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.filter(Model.id == config_id, Model.enabled.is_(True))
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
db_model = cfg_result.scalars().first()
|
||||
if (
|
||||
not db_model
|
||||
or not db_model.connection
|
||||
or not db_model.connection.enabled
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
conn = db_model.connection
|
||||
if (
|
||||
conn.search_space_id is not None
|
||||
and conn.search_space_id != search_space_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if (
|
||||
conn.user_id is not None
|
||||
and conn.user_id != search_space.user_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if not has_capability(db_model, "image_gen"):
|
||||
err = f"Model {config_id} is not image-generation capable"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
db_model.connection,
|
||||
db_model.model_id,
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
|
|
@ -230,7 +219,7 @@ def create_generate_image_tool(
|
|||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_generation_config_id=config_id,
|
||||
image_gen_model_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
|
|
|
|||
|
|
@ -51,8 +51,6 @@ def load_tools(
|
|||
create_generate_image_tool(
|
||||
search_space_id=d["search_space_id"],
|
||||
db_session=d["db_session"],
|
||||
image_generation_config_id_override=d.get(
|
||||
"image_generation_config_id_override"
|
||||
),
|
||||
image_gen_model_id_override=d.get("image_gen_model_id_override"),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
LLM configuration utilities for SurfSense agents.
|
||||
|
||||
This module provides functions for loading LLM configurations from:
|
||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||
1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model
|
||||
2. YAML files (global configs with negative IDs)
|
||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
3. Database model-connections table (user-created configs with positive IDs)
|
||||
|
||||
It also provides utilities for creating ChatLiteLLM instances and
|
||||
managing prompt configurations.
|
||||
|
|
@ -24,8 +24,6 @@ from langchain_core.messages import AIMessage, BaseMessage
|
|||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
|
|
@ -33,10 +31,7 @@ from app.agents.chat.runtime.prompt_caching import (
|
|||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -51,16 +46,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
|||
reject the blank text. The OpenAI spec says ``content`` should be
|
||||
``null`` when an assistant message only carries tool calls.
|
||||
"""
|
||||
sanitized: list[BaseMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, list):
|
||||
msg.content = _sanitize_content(msg.content)
|
||||
next_msg = msg.model_copy(deep=True)
|
||||
if isinstance(next_msg.content, list):
|
||||
next_msg.content = _sanitize_content(next_msg.content)
|
||||
if (
|
||||
isinstance(msg, AIMessage)
|
||||
and (not msg.content or msg.content == "")
|
||||
and getattr(msg, "tool_calls", None)
|
||||
isinstance(next_msg, AIMessage)
|
||||
and (not next_msg.content or next_msg.content == "")
|
||||
and getattr(next_msg, "tool_calls", None)
|
||||
):
|
||||
msg.content = None # type: ignore[assignment]
|
||||
return messages
|
||||
next_msg.content = None # type: ignore[assignment]
|
||||
sanitized.append(next_msg)
|
||||
return sanitized
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||
|
|
@ -91,13 +89,21 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
|||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
|
||||
# in provider_capabilities so the YAML loader can resolve prefixes during
|
||||
# app.config init without importing the agent/tools tree.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
stream: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return await super()._agenerate(
|
||||
_sanitize_messages(messages),
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
|
|
@ -121,8 +127,9 @@ class AgentConfig:
|
|||
"""
|
||||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||
This combines resolved model settings with prompt configuration.
|
||||
Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
|
||||
a concrete global or BYOK model before constructing ChatLiteLLM.
|
||||
"""
|
||||
|
||||
# LLM Model Settings
|
||||
|
|
@ -170,7 +177,7 @@ class AgentConfig:
|
|||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto (Fastest)",
|
||||
config_name="Auto",
|
||||
is_auto_mode=True,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
|
|
@ -181,64 +188,21 @@ class AgentConfig:
|
|||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a NewLLMConfig database model."""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
custom_provider=config.custom_provider,
|
||||
litellm_params=config.litellm_params,
|
||||
system_instructions=config.system_instructions,
|
||||
use_default_system_instructions=config.use_default_system_instructions,
|
||||
citations_enabled=config.citations_enabled,
|
||||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
|
||||
# unknown). The streaming safety net still blocks explicit text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a YAML configuration dictionary.
|
||||
|
||||
Supports the same prompt fields as NewLLMConfig (system_instructions,
|
||||
use_default_system_instructions, citations_enabled).
|
||||
Supports prompt fields such as system_instructions,
|
||||
use_default_system_instructions, and citations_enabled.
|
||||
"""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
provider = yaml_config.get("provider") or yaml_config.get(
|
||||
"litellm_provider", ""
|
||||
)
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
|
|
@ -324,93 +288,15 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
|||
return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load a NewLLMConfig from the database by ID."""
|
||||
from app.db import NewLLMConfig
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
print(f"Error: NewLLMConfig with id {config_id} not found")
|
||||
return None
|
||||
|
||||
return AgentConfig.from_new_llm_config(config)
|
||||
except Exception as e:
|
||||
print(f"Error loading NewLLMConfig from database: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_llm_config_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load the agent LLM config for a search space via its agent_llm_id.
|
||||
|
||||
Positive id -> DB; negative -> YAML; None -> first global config (-1).
|
||||
"""
|
||||
from app.db import SearchSpace
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
||||
return None
|
||||
|
||||
config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
return await load_agent_config(session, config_id, search_space_id)
|
||||
except Exception as e:
|
||||
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_config(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
search_space_id: int | None = None,
|
||||
) -> "AgentConfig | None":
|
||||
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
|
||||
if is_auto_mode(config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# In-memory covers static YAML + dynamic OpenRouter configs.
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
return None
|
||||
else:
|
||||
return await load_new_llm_config_from_db(session, config_id)
|
||||
|
||||
|
||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||
"""Create a ChatLiteLLM instance from a global LLM config dictionary."""
|
||||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
provider = llm_config.get("provider", "").upper()
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
||||
provider = llm_config.get("provider") or llm_config.get(
|
||||
"litellm_provider", "openai"
|
||||
)
|
||||
model_string = f"{provider}/{llm_config['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
@ -433,29 +319,17 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
def create_chat_litellm_from_agent_config(
|
||||
agent_config: AgentConfig,
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
|
||||
"""Create a ChatLiteLLM from an already resolved concrete model config."""
|
||||
if agent_config.is_auto_mode:
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
router_llm = get_auto_mode_llm()
|
||||
if router_llm is not None:
|
||||
# Universal injection points only: auto-mode fans out across
|
||||
# providers, so provider-specific kwargs have no known target.
|
||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||
return router_llm
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
print(
|
||||
"Error: Auto mode must be resolved to a concrete model before LLM creation"
|
||||
)
|
||||
return None
|
||||
|
||||
if agent_config.custom_provider:
|
||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||
else:
|
||||
provider_prefix = PROVIDER_MAP.get(
|
||||
agent_config.provider, agent_config.provider.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
||||
model_string = f"{agent_config.provider}/{agent_config.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from app.config import (
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||
|
|
@ -622,7 +621,6 @@ async def lifespan(app: FastAPI):
|
|||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
|
||||
# worker readiness. ``shield`` so Uvicorn cancelling startup
|
||||
|
|
|
|||
|
|
@ -39,31 +39,31 @@ async def build_dependencies(
|
|||
*,
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
agent_llm_id: int | None = None,
|
||||
image_generation_config_id: int | None = None,
|
||||
vision_llm_config_id: int | None = None,
|
||||
chat_model_id: int | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
vision_model_id: int | None = None,
|
||||
) -> AgentDependencies:
|
||||
"""Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer.
|
||||
|
||||
Resolves the agent LLM from the automation's *captured* model snapshot
|
||||
(``agent_llm_id``) so runs are insulated from later chat/search-space model
|
||||
Resolves the chat model from the automation's *captured* model snapshot
|
||||
(``chat_model_id``) so runs are insulated from later chat/search-space model
|
||||
changes. The model policy is enforced here as a runtime backstop: a captured
|
||||
model that is no longer billable (e.g. a premium global config was removed)
|
||||
fails the run clearly instead of silently consuming a free model.
|
||||
|
||||
When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback),
|
||||
fall back to the live search space's ``agent_llm_id`` and validate that.
|
||||
When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback),
|
||||
fall back to the live search space's ``chat_model_id`` and validate that.
|
||||
"""
|
||||
if agent_llm_id is not None:
|
||||
if chat_model_id is not None:
|
||||
try:
|
||||
assert_models_billable(
|
||||
agent_llm_id=agent_llm_id,
|
||||
image_generation_config_id=image_generation_config_id,
|
||||
vision_llm_config_id=vision_llm_config_id,
|
||||
chat_model_id=chat_model_id,
|
||||
image_gen_model_id=image_gen_model_id,
|
||||
vision_model_id=vision_model_id,
|
||||
)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise DependencyError(str(exc)) from exc
|
||||
resolved_agent_llm_id = agent_llm_id or 0
|
||||
resolved_chat_model_id = chat_model_id or 0
|
||||
else:
|
||||
search_space = await session.get(SearchSpace, search_space_id)
|
||||
if search_space is None:
|
||||
|
|
@ -72,15 +72,15 @@ async def build_dependencies(
|
|||
assert_automation_models_billable(search_space)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise DependencyError(str(exc)) from exc
|
||||
resolved_agent_llm_id = search_space.agent_llm_id or 0
|
||||
resolved_chat_model_id = search_space.chat_model_id or 0
|
||||
|
||||
llm, agent_config, err = await load_llm_bundle(
|
||||
session,
|
||||
config_id=resolved_agent_llm_id,
|
||||
config_id=resolved_chat_model_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if err is not None or llm is None:
|
||||
raise DependencyError(err or "failed to load agent LLM config")
|
||||
raise DependencyError(err or "failed to load chat model config")
|
||||
|
||||
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||
session, search_space_id=search_space_id
|
||||
|
|
|
|||
|
|
@ -150,9 +150,9 @@ async def run_agent_task(
|
|||
deps = await build_dependencies(
|
||||
session=agent_session,
|
||||
search_space_id=ctx.search_space_id,
|
||||
agent_llm_id=ctx.agent_llm_id,
|
||||
image_generation_config_id=ctx.image_generation_config_id,
|
||||
vision_llm_config_id=ctx.vision_llm_config_id,
|
||||
chat_model_id=ctx.chat_model_id,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
vision_model_id=ctx.vision_model_id,
|
||||
)
|
||||
|
||||
agent = await create_multi_agent_chat_deep_agent(
|
||||
|
|
@ -167,7 +167,7 @@ async def run_agent_task(
|
|||
firecrawl_api_key=deps.firecrawl_api_key,
|
||||
thread_visibility=ChatVisibility.PRIVATE,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
image_generation_config_id=ctx.image_generation_config_id,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
)
|
||||
|
||||
agent_query, runtime_context = await _resolve_mention_context(
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class ActionContext:
|
|||
# Captured model snapshot from the automation definition (``definition.models``),
|
||||
# resolved per run instead of the live search space. ``None`` falls back to the
|
||||
# search space's current prefs (defensive; should not happen post-capture).
|
||||
agent_llm_id: int | None = None
|
||||
image_generation_config_id: int | None = None
|
||||
vision_llm_config_id: int | None = None
|
||||
chat_model_id: int | None = None
|
||||
image_gen_model_id: int | None = None
|
||||
vision_model_id: int | None = None
|
||||
|
||||
|
||||
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
|
||||
|
|
|
|||
|
|
@ -132,9 +132,7 @@ def _build_action_ctx(
|
|||
step_id=step.step_id,
|
||||
search_space_id=automation.search_space_id,
|
||||
creator_user_id=automation.created_by_user_id,
|
||||
agent_llm_id=models.agent_llm_id if models else None,
|
||||
image_generation_config_id=(
|
||||
models.image_generation_config_id if models else None
|
||||
),
|
||||
vision_llm_config_id=models.vision_llm_config_id if models else None,
|
||||
chat_model_id=models.chat_model_id if models else None,
|
||||
image_gen_model_id=models.image_gen_model_id if models else None,
|
||||
vision_model_id=models.vision_model_id if models else None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec
|
|||
class AutomationModels(BaseModel):
|
||||
"""Captured model profile for an automation.
|
||||
|
||||
Snapshotted from the search space's preferences at create time so runs are
|
||||
insulated from later chat/search-space model changes. Config-id conventions
|
||||
Snapshotted from the search space's model roles at create time so runs are
|
||||
insulated from later chat/search-space model changes. Model-id conventions
|
||||
match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_llm_id: int = 0
|
||||
image_generation_config_id: int = 0
|
||||
vision_llm_config_id: int = 0
|
||||
chat_model_id: int = 0
|
||||
image_gen_model_id: int = 0
|
||||
vision_model_id: int = 0
|
||||
|
||||
|
||||
class AutomationDefinition(BaseModel):
|
||||
|
|
|
|||
|
|
@ -57,9 +57,9 @@ class AutomationService:
|
|||
else:
|
||||
search_space = await self._assert_models_billable(payload.search_space_id)
|
||||
payload.definition.models = AutomationModels(
|
||||
agent_llm_id=search_space.agent_llm_id or 0,
|
||||
image_generation_config_id=search_space.image_generation_config_id or 0,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id or 0,
|
||||
chat_model_id=search_space.chat_model_id or 0,
|
||||
image_gen_model_id=search_space.image_gen_model_id or 0,
|
||||
vision_model_id=search_space.vision_model_id or 0,
|
||||
)
|
||||
|
||||
automation = Automation(
|
||||
|
|
@ -225,9 +225,9 @@ class AutomationService:
|
|||
"""
|
||||
try:
|
||||
assert_models_billable(
|
||||
agent_llm_id=models.agent_llm_id,
|
||||
image_generation_config_id=models.image_generation_config_id,
|
||||
vision_llm_config_id=models.vision_llm_config_id,
|
||||
chat_model_id=models.chat_model_id,
|
||||
image_gen_model_id=models.image_gen_model_id,
|
||||
vision_model_id=models.vision_model_id,
|
||||
)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
Automations run unattended, so every run must be **billable**: it may only use
|
||||
either a premium global model (``billing_tier == "premium"``) or a user-provided
|
||||
BYOK model (a positive config id pointing at a per-user/per-space DB row). Free
|
||||
BYOK model (a positive model id pointing at a per-user/per-space DB row). Free
|
||||
global models and Auto mode are blocked, because Auto can dispatch to a free
|
||||
deployment and free models aren't metered in premium credits.
|
||||
|
||||
Config id conventions (shared across chat / image / vision):
|
||||
Model id conventions (shared across chat / image / vision):
|
||||
- ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` /
|
||||
``VISION_AUTO_MODE_ID``). Blocked.
|
||||
- ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium.
|
||||
|
|
@ -24,70 +24,45 @@ from typing import TYPE_CHECKING, Literal
|
|||
if TYPE_CHECKING:
|
||||
from app.db import SearchSpace
|
||||
|
||||
ModelKind = Literal["llm", "image", "vision"]
|
||||
ModelKind = Literal["chat", "image", "vision"]
|
||||
|
||||
_KIND_LABEL: dict[ModelKind, str] = {
|
||||
"llm": "agent LLM",
|
||||
"chat": "chat model",
|
||||
"image": "image generation model",
|
||||
"vision": "vision model",
|
||||
}
|
||||
|
||||
|
||||
def _is_premium_global(kind: ModelKind, config_id: int) -> bool:
|
||||
"""Return True if a negative (global) config id is a premium tier model."""
|
||||
def _is_premium_global(model_id: int) -> bool:
|
||||
"""Return True if a negative (global) model id is a premium tier model."""
|
||||
from app.config import config as app_config
|
||||
|
||||
cfg: dict | None = None
|
||||
if kind == "llm":
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
load_global_llm_config_by_id,
|
||||
)
|
||||
|
||||
cfg = load_global_llm_config_by_id(config_id)
|
||||
elif kind == "image":
|
||||
cfg = next(
|
||||
(
|
||||
c
|
||||
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
if c.get("id") == config_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
else: # vision
|
||||
cfg = next(
|
||||
(
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if c.get("id") == config_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not cfg:
|
||||
model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
if not model:
|
||||
return False
|
||||
return str(cfg.get("billing_tier", "free")).lower() == "premium"
|
||||
return str(model.get("billing_tier", "free")).lower() == "premium"
|
||||
|
||||
|
||||
def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]:
|
||||
"""Classify a resolved config id as allowed or blocked.
|
||||
def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]:
|
||||
"""Classify a resolved model id as allowed or blocked.
|
||||
|
||||
Returns ``(allowed, reason)``; ``reason`` is empty when allowed.
|
||||
"""
|
||||
label = _KIND_LABEL[kind]
|
||||
|
||||
if config_id is None or config_id == 0:
|
||||
if model_id is None or model_id == 0:
|
||||
return (
|
||||
False,
|
||||
f"The {label} is set to Auto mode. Automations require an explicit "
|
||||
"premium model or your own (BYOK) model so every run is billable.",
|
||||
)
|
||||
|
||||
if config_id > 0:
|
||||
# Positive id → user-owned BYOK config. Always allowed.
|
||||
if model_id > 0:
|
||||
# Positive id -> user/search-space BYOK model. Always allowed.
|
||||
return True, ""
|
||||
|
||||
# Negative id → global config. Allowed only if premium.
|
||||
if _is_premium_global(kind, config_id):
|
||||
# Negative id -> global model. Allowed only if premium.
|
||||
if _is_premium_global(model_id):
|
||||
return True, ""
|
||||
|
||||
return (
|
||||
|
|
@ -99,27 +74,27 @@ def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]:
|
|||
|
||||
def get_model_eligibility(
|
||||
*,
|
||||
agent_llm_id: int | None,
|
||||
image_generation_config_id: int | None,
|
||||
vision_llm_config_id: int | None,
|
||||
chat_model_id: int | None,
|
||||
image_gen_model_id: int | None,
|
||||
vision_model_id: int | None,
|
||||
) -> dict:
|
||||
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids.
|
||||
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids.
|
||||
|
||||
The ID-based core shared by both the search-space path (creation/eligibility)
|
||||
and the captured-snapshot path (runtime backstop). Each violation is
|
||||
``{"kind", "config_id", "reason"}``.
|
||||
``{"kind", "model_id", "reason"}``.
|
||||
"""
|
||||
checks: list[tuple[ModelKind, int | None]] = [
|
||||
("llm", agent_llm_id),
|
||||
("image", image_generation_config_id),
|
||||
("vision", vision_llm_config_id),
|
||||
("chat", chat_model_id),
|
||||
("image", image_gen_model_id),
|
||||
("vision", vision_model_id),
|
||||
]
|
||||
|
||||
violations: list[dict] = []
|
||||
for kind, config_id in checks:
|
||||
allowed, reason = _classify(kind, config_id)
|
||||
for kind, model_id in checks:
|
||||
allowed, reason = _classify(kind, model_id)
|
||||
if not allowed:
|
||||
violations.append({"kind": kind, "config_id": config_id, "reason": reason})
|
||||
violations.append({"kind": kind, "model_id": model_id, "reason": reason})
|
||||
|
||||
return {"allowed": not violations, "violations": violations}
|
||||
|
||||
|
|
@ -131,9 +106,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict:
|
|||
wrapper over :func:`get_model_eligibility`.
|
||||
"""
|
||||
return get_model_eligibility(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
chat_model_id=search_space.chat_model_id,
|
||||
image_gen_model_id=search_space.image_gen_model_id,
|
||||
vision_model_id=search_space.vision_model_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -150,9 +125,9 @@ class AutomationModelPolicyError(Exception):
|
|||
|
||||
def assert_models_billable(
|
||||
*,
|
||||
agent_llm_id: int | None,
|
||||
image_generation_config_id: int | None,
|
||||
vision_llm_config_id: int | None,
|
||||
chat_model_id: int | None,
|
||||
image_gen_model_id: int | None,
|
||||
vision_model_id: int | None,
|
||||
) -> None:
|
||||
"""Raise :class:`AutomationModelPolicyError` if any explicit id is not billable.
|
||||
|
||||
|
|
@ -160,9 +135,9 @@ def assert_models_billable(
|
|||
captured model snapshot.
|
||||
"""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=agent_llm_id,
|
||||
image_generation_config_id=image_generation_config_id,
|
||||
vision_llm_config_id=vision_llm_config_id,
|
||||
chat_model_id=chat_model_id,
|
||||
image_gen_model_id=image_gen_model_id,
|
||||
vision_model_id=vision_model_id,
|
||||
)
|
||||
if not result["allowed"]:
|
||||
raise AutomationModelPolicyError(result["violations"])
|
||||
|
|
|
|||
|
|
@ -115,14 +115,12 @@ def init_worker(**kwargs):
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_openrouter_integration()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
|
||||
# Celery configuration, sourced from the central Config singleton
|
||||
|
|
|
|||
|
|
@ -78,8 +78,7 @@ def load_global_llm_configs():
|
|||
# stamps) never leak into the cached YAML structure.
|
||||
configs = copy.deepcopy(data.get("global_llm_configs", []))
|
||||
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way
|
||||
# and matches the `provider_api_base` pattern used elsewhere.
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
seen_slugs: dict[str, int] = {}
|
||||
|
|
@ -104,7 +103,7 @@ def load_global_llm_configs():
|
|||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -120,10 +119,10 @@ def load_global_llm_configs():
|
|||
else:
|
||||
seen_slugs[slug] = cfg.get("id", 0)
|
||||
|
||||
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
|
||||
# Stamp Auto ranking metadata. YAML configs are always
|
||||
# Tier A — operator-curated, locked first when premium-eligible.
|
||||
# The OpenRouter refresh tick later re-stamps health for any cfg
|
||||
# whose provider == "OPENROUTER" via _enrich_health.
|
||||
# whose provider == "openrouter" via _enrich_health.
|
||||
try:
|
||||
from app.services.quality_score import static_score_yaml
|
||||
|
||||
|
|
@ -133,7 +132,7 @@ def load_global_llm_configs():
|
|||
cfg["quality_score_static"] = static_q
|
||||
cfg["quality_score"] = static_q
|
||||
cfg["quality_score_health"] = None
|
||||
# YAML cfgs whose provider is OPENROUTER are also subject
|
||||
# YAML cfgs whose provider is openrouter are also subject
|
||||
# to health gating against their own /endpoints data — a
|
||||
# hand-picked dead OR model is still dead. _enrich_health
|
||||
# re-stamps health_gated for them on the next refresh tick.
|
||||
|
|
@ -211,42 +210,6 @@ def load_global_image_gen_configs():
|
|||
return []
|
||||
|
||||
|
||||
def load_global_vision_llm_configs():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return []
|
||||
|
||||
try:
|
||||
configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def load_vision_llm_router_settings():
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
settings = data.get("vision_llm_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load vision LLM router settings: {e}")
|
||||
return default_settings
|
||||
|
||||
|
||||
def load_image_gen_router_settings():
|
||||
"""
|
||||
Load router settings for image generation Auto mode from YAML file.
|
||||
|
|
@ -363,8 +326,8 @@ def initialize_openrouter_integration():
|
|||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
|
||||
# Image generation + vision LLM emissions are opt-in (issue L).
|
||||
# Both reuse the catalogue already cached by ``service.initialize``
|
||||
# Image generation emissions reuse the catalogue already cached by
|
||||
# ``service.initialize``
|
||||
# so we don't make additional network calls here.
|
||||
if settings.get("image_generation_enabled"):
|
||||
try:
|
||||
|
|
@ -378,21 +341,26 @@ def initialize_openrouter_integration():
|
|||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
try:
|
||||
vision_configs = service.get_vision_llm_configs()
|
||||
if vision_configs:
|
||||
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(vision_configs)} "
|
||||
f"vision LLM models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
|
||||
refresh_global_model_catalog()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||
|
||||
|
||||
def materialize_global_configs():
|
||||
from app.services.global_model_catalog import materialize_global_model_catalog
|
||||
|
||||
return materialize_global_model_catalog(
|
||||
chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []),
|
||||
image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []),
|
||||
)
|
||||
|
||||
|
||||
def refresh_global_model_catalog():
|
||||
connections, models = materialize_global_configs()
|
||||
config.GLOBAL_CONNECTIONS = connections
|
||||
config.GLOBAL_MODELS = models
|
||||
|
||||
|
||||
def initialize_pricing_registration():
|
||||
"""
|
||||
Teach LiteLLM the per-token cost of every deployment in
|
||||
|
|
@ -430,7 +398,10 @@ def initialize_llm_router():
|
|||
router_settings = config.ROUTER_SETTINGS
|
||||
|
||||
if not all_configs:
|
||||
print("Info: No global LLM configs found, Auto mode will not be available")
|
||||
print(
|
||||
"Info: No global LLM configs found; global Auto pool is unavailable. "
|
||||
"Auto can still use enabled BYOK models."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -475,32 +446,6 @@ def initialize_image_gen_router():
|
|||
print(f"Warning: Failed to initialize Image Generation Router: {e}")
|
||||
|
||||
|
||||
def initialize_vision_llm_router():
|
||||
vision_configs = load_global_vision_llm_configs()
|
||||
# Reuse the router settings already parsed at Config construction. The
|
||||
# *configs* list is intentionally re-read from YAML (it must exclude the
|
||||
# OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS).
|
||||
router_settings = config.VISION_LLM_ROUTER_SETTINGS
|
||||
|
||||
if not vision_configs:
|
||||
print(
|
||||
"Info: No global vision LLM configs found, "
|
||||
"Vision LLM Auto mode will not be available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
VisionLLMRouterService.initialize(vision_configs, router_settings)
|
||||
print(
|
||||
f"Info: Vision LLM Router initialized with {len(vision_configs)} models "
|
||||
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize Vision LLM Router: {e}")
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -762,7 +707,7 @@ class Config:
|
|||
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
|
||||
)
|
||||
|
||||
# Per-podcast reservation (in micro-USD). One agent LLM call generating
|
||||
# Per-podcast reservation (in micro-USD). One chat model call generating
|
||||
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
|
||||
# premium-model run. Tune via env.
|
||||
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
|
||||
|
|
@ -882,11 +827,17 @@ class Config:
|
|||
# Router settings for Image Generation Auto mode
|
||||
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
|
||||
|
||||
# Global Vision LLM Configurations (optional)
|
||||
GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs()
|
||||
# Virtual GLOBAL connection/model catalog. This is server-only metadata
|
||||
# derived from global_llm_config.yaml; GLOBAL keys are not stored in DB.
|
||||
from app.services.global_model_catalog import (
|
||||
materialize_global_model_catalog as _materialize_global_model_catalog,
|
||||
)
|
||||
|
||||
# Router settings for Vision LLM Auto mode
|
||||
VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings()
|
||||
GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog(
|
||||
chat_configs=GLOBAL_LLM_CONFIGS,
|
||||
image_configs=GLOBAL_IMAGE_GEN_CONFIGS,
|
||||
)
|
||||
del _materialize_global_model_catalog
|
||||
|
||||
# OpenRouter Integration settings (optional)
|
||||
OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings()
|
||||
|
|
|
|||
|
|
@ -1,362 +1,236 @@
|
|||
# Global LLM Configuration
|
||||
#
|
||||
# SETUP INSTRUCTIONS:
|
||||
# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys
|
||||
# 2. For testing: The system will use this example file automatically if global_llm_config.yaml doesn't exist
|
||||
# 1. Copy this file to global_llm_config.yaml.
|
||||
# 2. Replace placeholder credentials, endpoints, deployment names, and pricing
|
||||
# with values from your own provider accounts.
|
||||
#
|
||||
# NOTE: The example API keys below are placeholders and won't work.
|
||||
# Replace them with your actual API keys to enable global configurations.
|
||||
# This file is intentionally safe to commit. Do not put real API keys in this
|
||||
# example file.
|
||||
#
|
||||
# These configurations will be available to all users as a convenient option
|
||||
# Users can choose to use these global configs or add their own
|
||||
# These YAML entries are materialized at startup as server-owned GLOBAL
|
||||
# connections and models:
|
||||
#
|
||||
# AUTO MODE (Recommended):
|
||||
# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs
|
||||
# - This helps avoid rate limits by distributing requests across multiple providers
|
||||
# - New users are automatically assigned Auto mode by default
|
||||
# - Configure router_settings below to customize the load balancing behavior
|
||||
# global_llm_configs -> GLOBAL chat models
|
||||
# global_image_generation_configs -> GLOBAL image generation models
|
||||
#
|
||||
# Structure matches NewLLMConfig:
|
||||
# - Model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
# Do not add global_connections or global_models sections here. They are
|
||||
# runtime-derived metadata exposed through the model-connections APIs.
|
||||
#
|
||||
# Static config shape:
|
||||
# - Connection fields: provider, api_key, api_base, api_version
|
||||
# - Model fields: model_name, billing_tier, rpm/tpm, capabilities, litellm_params
|
||||
# - Public no-login SEO metadata: seo_title, seo_description
|
||||
# - Prompt defaults: system_instructions, use_default_system_instructions,
|
||||
# citations_enabled
|
||||
#
|
||||
# Provider notes:
|
||||
# - Use the canonical provider field.
|
||||
# - For Azure, use the bare deployment name in model_name, for example
|
||||
# model_name: "gpt-5.1". The resolver prefixes the LiteLLM model string from
|
||||
# provider: "azure".
|
||||
#
|
||||
# GLOBAL ID namespace:
|
||||
# - ID 0 is reserved for Auto mode.
|
||||
# - Negative IDs are server-owned GLOBAL models.
|
||||
# - Positive IDs are user/BYOK database models.
|
||||
# - Keep static IDs unique across chat and image generation.
|
||||
# - Suggested static ranges: chat -1..-999, image -2001..-2999.
|
||||
# - Vision is not a separate config/table. Chat models that accept images use
|
||||
# supports_image_input: true.
|
||||
#
|
||||
# COST-BASED PREMIUM CREDITS:
|
||||
# Each premium config bills the user's USD-credit balance based on the
|
||||
# actual provider cost reported by LiteLLM. For models LiteLLM already
|
||||
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
|
||||
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
|
||||
# or any model LiteLLM doesn't have in its built-in pricing table, declare
|
||||
# per-token costs inline so they bill correctly:
|
||||
# Each premium model bills the user's USD-credit balance based on provider cost
|
||||
# reported by LiteLLM. For custom Azure deployments or any model LiteLLM does
|
||||
# not know, declare per-token costs inline:
|
||||
#
|
||||
# litellm_params:
|
||||
# base_model: "my-custom-azure-deploy"
|
||||
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
|
||||
# input_cost_per_token: 0.000003
|
||||
# output_cost_per_token: 0.000015
|
||||
# base_model: "my-custom-deployment"
|
||||
# # USD per token; 0.00000125 == $1.25 per million input tokens.
|
||||
# input_cost_per_token: 0.00000125
|
||||
# output_cost_per_token: 0.00001
|
||||
#
|
||||
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
|
||||
# API — no inline declaration needed. Models without resolvable pricing
|
||||
# debit $0 from the user's balance and log a WARNING.
|
||||
# OpenRouter dynamic chat models pull pricing automatically from OpenRouter's
|
||||
# API. Models without resolvable pricing debit $0 and log a warning.
|
||||
|
||||
# Router Settings for Auto Mode
|
||||
# These settings control how the LiteLLM Router distributes requests across models
|
||||
# =============================================================================
|
||||
# Chat Auto Mode Router Settings
|
||||
# =============================================================================
|
||||
# These settings control how the LiteLLM Router distributes Auto-mode requests
|
||||
# across curated router-eligible GLOBAL chat deployments.
|
||||
router_settings:
|
||||
# Routing strategy options:
|
||||
# - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits)
|
||||
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting
|
||||
# - "least-busy": Routes to least busy deployment
|
||||
# - "latency-based-routing": Routes based on response latency
|
||||
# - "usage-based-routing": Routes to deployment with lowest current usage.
|
||||
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting.
|
||||
# - "least-busy": Routes to least busy deployment.
|
||||
# - "latency-based-routing": Routes based on response latency.
|
||||
routing_strategy: "usage-based-routing"
|
||||
|
||||
# Number of retries before failing
|
||||
num_retries: 3
|
||||
|
||||
# Number of failures allowed before cooling down a deployment
|
||||
allowed_fails: 3
|
||||
|
||||
# Cooldown time in seconds after allowed_fails is exceeded
|
||||
cooldown_time: 60
|
||||
# Optional fallback map:
|
||||
# fallbacks:
|
||||
# - {"azure/gpt-5.1": ["azure/gpt-5.4-mini"]}
|
||||
|
||||
# Fallback models (optional) - when primary fails, try these
|
||||
# Format: [{"primary_model": ["fallback1", "fallback2"]}]
|
||||
# fallbacks: []
|
||||
|
||||
# =============================================================================
|
||||
# Static GLOBAL Chat Models
|
||||
# =============================================================================
|
||||
global_llm_configs:
|
||||
# Example: OpenAI GPT-4 Turbo with citations enabled
|
||||
# Premium Azure chat model with image input support and explicit custom
|
||||
# pricing. This is the current shape to use for hosted GPT 5.x deployments.
|
||||
- id: -1
|
||||
name: "Global GPT-4 Turbo"
|
||||
description: "OpenAI's GPT-4 Turbo with default prompts and citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-4-turbo"
|
||||
name: "Azure GPT 5.1"
|
||||
billing_tier: "premium"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
seo_slug: "azure-gpt-5-1"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4-turbo-preview"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
# Rate limits for load balancing (requests/tokens per minute)
|
||||
rpm: 500 # Requests per minute
|
||||
tpm: 100000 # Tokens per minute
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.1"
|
||||
supports_image_input: true
|
||||
supports_tools: true
|
||||
max_input_tokens: 400000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version is optional. Include it if your Azure deployment requires a
|
||||
# specific API version.
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 47500
|
||||
tpm: 14750000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# Prompt Configuration
|
||||
system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.1"
|
||||
input_cost_per_token: 0.00000125
|
||||
output_cost_per_token: 0.00001
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Anthropic Claude 3 Opus
|
||||
# Larger premium chat model. If your provider prices long-context traffic
|
||||
# differently, choose a conservative flat price or document the limitation
|
||||
# next to the inline pricing.
|
||||
- id: -2
|
||||
name: "Global Claude 3 Opus"
|
||||
description: "Anthropic's most capable model with citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "claude-3-opus"
|
||||
name: "Azure GPT 5.4"
|
||||
billing_tier: "premium"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
seo_slug: "azure-gpt-5-4"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-opus-20240229"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4"
|
||||
supports_image_input: true
|
||||
supports_tools: true
|
||||
max_input_tokens: 400000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
rpm: 150000
|
||||
tpm: 15000000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.4"
|
||||
input_cost_per_token: 0.0000025
|
||||
output_cost_per_token: 0.000015
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Fast model - GPT-3.5 Turbo (citations disabled for speed)
|
||||
# Free/no-login hosted model. Free models are visible to users when
|
||||
# anonymous_enabled/seo_enabled are true but do not debit premium credits.
|
||||
- id: -3
|
||||
name: "Global GPT-3.5 Turbo (Fast)"
|
||||
description: "Fast responses without citations for quick queries"
|
||||
name: "Azure GPT 5.4 Mini"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-3.5-turbo-fast"
|
||||
quota_reserve_tokens: 2000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: false # Disabled for faster responses
|
||||
|
||||
# Example: Chinese LLM - DeepSeek with custom instructions
|
||||
- id: -4
|
||||
name: "Global DeepSeek Chat (Chinese)"
|
||||
description: "DeepSeek optimized for Chinese language responses"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "deepseek-chat-chinese"
|
||||
seo_slug: "gpt-5-4-mini-no-login"
|
||||
seo_title: "Free GPT 5.4 Mini Chat"
|
||||
seo_description: "Chat with a hosted GPT 5.4 Mini model without signing in."
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "DEEPSEEK"
|
||||
model_name: "deepseek-chat"
|
||||
api_key: "your-deepseek-api-key-here"
|
||||
api_base: "https://api.deepseek.com/v1"
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# Custom system instructions for Chinese responses
|
||||
system_instructions: |
|
||||
<system_instruction>
|
||||
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language.
|
||||
</system_instruction>
|
||||
use_default_system_instructions: false
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# IMPORTANT: For Azure deployments, always include 'base_model' in litellm_params
|
||||
# to enable accurate token counting, cost tracking, and max token limits
|
||||
- id: -5
|
||||
name: "Global Azure GPT-4o"
|
||||
description: "Azure OpenAI GPT-4o deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4o"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
# model_name format for Azure: azure/<your-deployment-name>
|
||||
model_name: "azure/gpt-4o-deployment"
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4-mini"
|
||||
supports_image_input: false
|
||||
supports_tools: true
|
||||
max_input_tokens: 128000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview" # Azure API version
|
||||
rpm: 1000
|
||||
tpm: 150000
|
||||
rpm: 15000
|
||||
tpm: 15000000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# REQUIRED for Azure: Specify the underlying OpenAI model
|
||||
# This fixes "Could not identify azure model" warnings
|
||||
# Common base_model values: gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-3.5-turbo
|
||||
base_model: "gpt-4o"
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.4-mini"
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Azure OpenAI GPT-4 Turbo
|
||||
- id: -6
|
||||
name: "Global Azure GPT-4 Turbo"
|
||||
description: "Azure OpenAI GPT-4 Turbo deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4-turbo"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
model_name: "azure/gpt-4-turbo-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Groq - Fast inference
|
||||
- id: -7
|
||||
name: "Global Groq Llama 3"
|
||||
description: "Ultra-fast Llama 3 70B via Groq"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "groq-llama-3"
|
||||
quota_reserve_tokens: 8000
|
||||
provider: "GROQ"
|
||||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 30 # Groq has lower rate limits on free tier
|
||||
tpm: 14400
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 8000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: MiniMax M3 - High-performance with 512K context window
|
||||
- id: -8
|
||||
name: "Global MiniMax M3"
|
||||
description: "MiniMax M3 with 512K context window and competitive pricing"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "minimax-m3"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "MINIMAX"
|
||||
model_name: "MiniMax-M3"
|
||||
api_key: "your-minimax-api-key-here"
|
||||
api_base: "https://api.minimax.io/v1"
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0
|
||||
max_tokens: 4000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Planner LLM - small, fast model used for internal utility tasks
|
||||
#
|
||||
# The PLANNER role handles short, structured internal calls (KB query
|
||||
# rewriting, date extraction, recency classification, etc.) that don't
|
||||
# need frontier-tier capability. Pointing the planner at a cheap+fast
|
||||
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
|
||||
# typically saves 500ms-1.5s per turn vs. routing those same internal
|
||||
# calls through the user's chat model.
|
||||
#
|
||||
# Activation:
|
||||
# - Mark EXACTLY ONE global config with ``is_planner: true``.
|
||||
# - If multiple are marked, the first one wins and a WARNING is logged.
|
||||
# - If none is marked, every internal call falls back to the user's
|
||||
# chat LLM (same behavior as before this flag existed).
|
||||
#
|
||||
# This config is operator-only — it is NOT exposed in the user-facing
|
||||
# model selector, never billed against premium quota, and the
|
||||
# billing_tier / anonymous_enabled fields below are ignored.
|
||||
# Planner LLM. This is operator-only and is not shown in the user-facing
|
||||
# model selector. Only one global_llm_configs entry should set is_planner.
|
||||
- id: -9
|
||||
name: "Global Planner (GPT-4o mini)"
|
||||
description: "Internal-only planner LLM for query rewriting and classification"
|
||||
name: "Azure GPT 5.x Nano Planner"
|
||||
is_planner: true
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
quota_reserve_tokens: 1000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4o-mini"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500
|
||||
tpm: 200000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4-nano"
|
||||
supports_image_input: false
|
||||
supports_tools: false
|
||||
router_pool_eligible: false
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
rpm: 20000
|
||||
tpm: 4000000
|
||||
litellm_params:
|
||||
temperature: 0
|
||||
max_tokens: 1000
|
||||
base_model: "gpt-5.4-nano"
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: false
|
||||
|
||||
# =============================================================================
|
||||
# OpenRouter Integration
|
||||
# OpenRouter Dynamic Model Integration
|
||||
# =============================================================================
|
||||
# When enabled, dynamically fetches ALL available models from the OpenRouter API
|
||||
# and injects them as global configs. This gives premium users access to any model
|
||||
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
|
||||
# while free-tier OpenRouter models show up with a green Free badge and do NOT
|
||||
# consume premium quota.
|
||||
# Models are fetched at startup and refreshed periodically in the background.
|
||||
# All calls go through LiteLLM with the openrouter/ prefix.
|
||||
# When enabled, SurfSense fetches the OpenRouter catalog at startup and injects
|
||||
# supported models as GLOBAL chat and optionally image-generation models.
|
||||
# Tier is derived per model from OpenRouter data:
|
||||
# - model id ends with ":free" -> billing_tier=free
|
||||
# - prompt and completion pricing are zero -> billing_tier=free
|
||||
# - otherwise -> billing_tier=premium
|
||||
#
|
||||
# Do not use deprecated openrouter_integration.billing_tier or
|
||||
# openrouter_integration.anonymous_enabled. Use the tier-specific anonymous
|
||||
# switches below.
|
||||
openrouter_integration:
|
||||
enabled: false
|
||||
api_key: "sk-or-your-openrouter-api-key"
|
||||
|
||||
# Tier is derived PER MODEL from OpenRouter's own API signals:
|
||||
# - id ends with ":free" -> billing_tier=free
|
||||
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
|
||||
# - otherwise -> billing_tier=premium
|
||||
# No global billing_tier knob is honored; any legacy value emits a startup warning.
|
||||
|
||||
# Anonymous access is split by tier so operators can expose only free
|
||||
# models to no-login users without leaking paid inference.
|
||||
anonymous_enabled_paid: false
|
||||
anonymous_enabled_free: false
|
||||
|
||||
seo_enabled: false
|
||||
# quota_reserve_tokens: tokens reserved per call for quota enforcement
|
||||
quota_reserve_tokens: 4000
|
||||
# id_offset: base negative ID for dynamically generated configs.
|
||||
# Model IDs are derived deterministically via BLAKE2b so they survive
|
||||
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
|
||||
|
||||
# Base negative ID namespace for dynamic chat models. IDs are derived
|
||||
# deterministically so they survive catalog churn. Do not overlap static IDs.
|
||||
id_offset: -10000
|
||||
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
|
||||
|
||||
# Separate base negative ID namespace for dynamic image-generation models.
|
||||
image_id_offset: -20000
|
||||
|
||||
# How often to refresh the OpenRouter catalog. 0 means startup only.
|
||||
refresh_interval_hours: 24
|
||||
|
||||
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
|
||||
# for per-deployment accounting when OR premium models participate in the
|
||||
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
|
||||
# real account limits live at https://openrouter.ai/settings/limits.
|
||||
# Paid OpenRouter models may join curated router pools when eligible.
|
||||
rpm: 200
|
||||
tpm: 1000000
|
||||
|
||||
# Rate limits for FREE OpenRouter models. Informational only: free OR
|
||||
# models are intentionally kept OUT of the LiteLLM Router pool, because
|
||||
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
|
||||
# 50-1000 daily requests across every ":free" model combined) —
|
||||
# per-deployment router accounting can't represent a shared bucket
|
||||
# correctly. Free OR models stay fully available in the model selector
|
||||
# and for user-facing Auto thread pinning.
|
||||
# Free OpenRouter models are available for user-facing selection/pinning but
|
||||
# should be treated as a shared-account bucket, not normal router capacity.
|
||||
free_rpm: 20
|
||||
free_tpm: 100000
|
||||
|
||||
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
|
||||
# contains hundreds of image- and vision-capable models; turning these on
|
||||
# injects them into the global Image-Generation / Vision-LLM model
|
||||
# selectors alongside any static configs. Tier (free/premium) is derived
|
||||
# per model the same way it is for chat (`:free` suffix or zero pricing).
|
||||
# When a user picks a premium image/vision model the call debits the
|
||||
# shared $5 USD-cost-based premium credit pool — so leaving these off
|
||||
# avoids surprise quota burn on existing deployments. Default: false.
|
||||
# Image generation is opt-in to avoid injecting a large image catalog during
|
||||
# upgrades. Vision-capable chat models are represented with
|
||||
# supports_image_input: true.
|
||||
image_generation_enabled: false
|
||||
vision_enabled: false
|
||||
|
||||
|
|
@ -367,191 +241,80 @@ openrouter_integration:
|
|||
citations_enabled: true
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Configuration
|
||||
# Image Generation Auto Mode Router Settings
|
||||
# =============================================================================
|
||||
# These configurations power the image generation feature using litellm.aimage_generation().
|
||||
# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock,
|
||||
# Recraft, OpenRouter, Xinference, Nscale
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs.
|
||||
|
||||
# Router Settings for Image Generation Auto Mode
|
||||
image_generation_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
# =============================================================================
|
||||
# Static GLOBAL Image Generation Models
|
||||
# =============================================================================
|
||||
global_image_generation_configs:
|
||||
# Example: OpenAI DALL-E 3
|
||||
- id: -1
|
||||
name: "Global DALL-E 3"
|
||||
description: "OpenAI's DALL-E 3 for high-quality image generation"
|
||||
provider: "OPENAI"
|
||||
model_name: "dall-e-3"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
|
||||
litellm_params: {}
|
||||
|
||||
# Example: OpenAI GPT Image 1
|
||||
- id: -2
|
||||
name: "Global GPT Image 1"
|
||||
description: "OpenAI's GPT Image 1 model"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-image-1"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50
|
||||
litellm_params: {}
|
||||
|
||||
# Example: Azure OpenAI DALL-E 3
|
||||
- id: -3
|
||||
name: "Global Azure DALL-E 3"
|
||||
description: "Azure-hosted DALL-E 3 deployment"
|
||||
provider: "AZURE_OPENAI"
|
||||
model_name: "azure/dall-e-3-deployment"
|
||||
- id: -2001
|
||||
name: "Azure GPT Image 1.5"
|
||||
billing_tier: "premium"
|
||||
provider: "azure"
|
||||
model_name: "gpt-image-1.5"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 50
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 60
|
||||
litellm_params:
|
||||
base_model: "dall-e-3"
|
||||
base_model: "gpt-image-1.5"
|
||||
|
||||
# Example: OpenRouter Gemini Image Generation
|
||||
# - id: -4
|
||||
# name: "Global Gemini Image Gen"
|
||||
# description: "Google Gemini image generation via OpenRouter"
|
||||
# provider: "OPENROUTER"
|
||||
# model_name: "google/gemini-2.5-flash-image"
|
||||
# api_key: "your-openrouter-api-key-here"
|
||||
# api_base: ""
|
||||
# rpm: 30
|
||||
# litellm_params: {}
|
||||
- id: -2002
|
||||
name: "Azure GPT Image 1 Mini"
|
||||
billing_tier: "free"
|
||||
provider: "azure"
|
||||
model_name: "gpt-image-1-mini"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 120
|
||||
litellm_params:
|
||||
base_model: "gpt-image-1-mini"
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM Configuration
|
||||
# Field Notes
|
||||
# =============================================================================
|
||||
# These configurations power the vision autocomplete feature (screenshot analysis).
|
||||
# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3).
|
||||
# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock,
|
||||
# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom
|
||||
# Common chat/image fields:
|
||||
# - provider: Canonical provider adapter name. Example: azure, openai,
|
||||
# anthropic, openrouter, groq, bedrock.
|
||||
# - model_name: Provider model or deployment id. For Azure, use the bare
|
||||
# deployment name. The resolver prefixes LiteLLM model strings from provider.
|
||||
# - api_base: Provider endpoint/root URL. For OpenAI-compatible providers, the
|
||||
# resolver adds /v1 when needed.
|
||||
# - api_version: Optional provider-specific API version, stored on the
|
||||
# materialized connection extra metadata.
|
||||
# - litellm_params: Passed to LiteLLM when invoking the model. Also used for
|
||||
# base_model and inline pricing registration.
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs.
|
||||
|
||||
# Router Settings for Vision LLM Auto Mode
|
||||
vision_llm_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
global_vision_llm_configs:
|
||||
# Example: OpenAI GPT-4o (recommended for vision)
|
||||
- id: -1
|
||||
name: "Global GPT-4o Vision"
|
||||
description: "OpenAI's GPT-4o with strong vision capabilities"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4o"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Google Gemini 2.0 Flash
|
||||
- id: -2
|
||||
name: "Global Gemini 2.0 Flash"
|
||||
description: "Google's fast vision model with large context"
|
||||
provider: "GOOGLE"
|
||||
model_name: "gemini-2.0-flash"
|
||||
api_key: "your-google-ai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Anthropic Claude 3.5 Sonnet
|
||||
- id: -3
|
||||
name: "Global Claude 3.5 Sonnet Vision"
|
||||
description: "Anthropic's Claude 3.5 Sonnet with vision support"
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-5-sonnet-20241022"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# - id: -4
|
||||
# name: "Global Azure GPT-4o Vision"
|
||||
# description: "Azure-hosted GPT-4o for vision analysis"
|
||||
# provider: "AZURE_OPENAI"
|
||||
# model_name: "azure/gpt-4o-deployment"
|
||||
# api_key: "your-azure-api-key-here"
|
||||
# api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2024-02-15-preview"
|
||||
# rpm: 500
|
||||
# tpm: 100000
|
||||
# litellm_params:
|
||||
# temperature: 0.3
|
||||
# max_tokens: 1000
|
||||
# base_model: "gpt-4o"
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
|
||||
# - The 'api_key' field will not be exposed to users via API
|
||||
# - system_instructions: Custom prompt or empty string to use defaults
|
||||
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
||||
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
|
||||
# - All standard LiteLLM providers are supported
|
||||
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||
# These help the router distribute load evenly and avoid rate limit errors
|
||||
# Chat model fields:
|
||||
# - supports_image_input: true when the chat model can consume image inputs.
|
||||
# - supports_tools: true when the model can use tools/function calling.
|
||||
# - max_input_tokens: Optional UI/catalog metadata for context size.
|
||||
# - router_pool_eligible: false keeps a model out of shared router pools while
|
||||
# still allowing direct selection/pinning.
|
||||
# - is_planner: true marks the internal-only planner model. Only one config
|
||||
# should set this flag.
|
||||
#
|
||||
# Catalog and access fields:
|
||||
# - billing_tier: "free" or "premium".
|
||||
# - anonymous_enabled: Whether the model appears in the public no-login catalog.
|
||||
# - seo_enabled: Whether a /free/<seo_slug> landing page is generated.
|
||||
# - seo_slug: Stable URL slug for SEO pages. Keep unique and do not change once
|
||||
# public.
|
||||
# - seo_title / seo_description: Optional SEO metadata overrides.
|
||||
# - quota_reserve_tokens: Tokens reserved before each chat LLM call.
|
||||
# - rpm / tpm: Optional rate limits for router accounting and load balancing.
|
||||
#
|
||||
# IMAGE GENERATION NOTES:
|
||||
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
|
||||
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
|
||||
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
|
||||
# - The router uses litellm.aimage_generation() for async image generation
|
||||
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
|
||||
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.
|
||||
#
|
||||
# VISION LLM NOTES:
|
||||
# - Vision configs use the same ID scheme (negative for global, positive for user DB)
|
||||
# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.)
|
||||
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
||||
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
||||
#
|
||||
# PLANNER LLM NOTES:
|
||||
# - is_planner: true marks a config as the internal-only planner LLM (small,
|
||||
# fast model used for KB query rewriting, date extraction, recency
|
||||
# classification, etc.). Only one config may carry this flag — if
|
||||
# multiple do, the first one wins and a startup WARNING is logged.
|
||||
# - When no config is marked is_planner, every internal utility call falls
|
||||
# back to the user's chat LLM (the historical behavior).
|
||||
# - Planner configs are NOT shown in the user-facing model selector and
|
||||
# are NOT billed against the user's premium quota. Their billing_tier,
|
||||
# anonymous_enabled, seo_* fields are ignored.
|
||||
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
|
||||
# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k
|
||||
# prompt. Frontier models here defeat the purpose of the flag.
|
||||
#
|
||||
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
||||
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
||||
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
||||
# - seo_enabled: true/false. Whether a /free/<seo_slug> landing page is generated.
|
||||
# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public.
|
||||
# - seo_title: Optional HTML title tag override for the model's /free/<slug> page.
|
||||
# - seo_description: Optional meta description override for the model's /free/<slug> page.
|
||||
# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement.
|
||||
# Independent of litellm_params.max_tokens. Used by the token quota service.
|
||||
# Image generation notes:
|
||||
# - Image-generation configs use the same GLOBAL ID namespace as chat models.
|
||||
# - Only RPM is relevant for most image-generation APIs.
|
||||
# - The runtime uses litellm.aimage_generation().
|
||||
# - Image billing currently uses billing_tier and model catalog metadata. Keep
|
||||
# quota reserve tuning in code/catalog unless the materializer copies a YAML
|
||||
# key for image quota reservation.
|
||||
|
|
|
|||
|
|
@ -198,79 +198,15 @@ class DocumentStatus:
|
|||
return None
|
||||
|
||||
|
||||
class LiteLLMProvider(StrEnum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
BEDROCK = "BEDROCK"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
GROQ = "GROQ"
|
||||
COHERE = "COHERE"
|
||||
MISTRAL = "MISTRAL"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
REPLICATE = "REPLICATE"
|
||||
PERPLEXITY = "PERPLEXITY"
|
||||
OLLAMA = "OLLAMA"
|
||||
ALIBABA_QWEN = "ALIBABA_QWEN"
|
||||
MOONSHOT = "MOONSHOT"
|
||||
ZHIPU = "ZHIPU"
|
||||
ANYSCALE = "ANYSCALE"
|
||||
DEEPINFRA = "DEEPINFRA"
|
||||
CEREBRAS = "CEREBRAS"
|
||||
SAMBANOVA = "SAMBANOVA"
|
||||
AI21 = "AI21"
|
||||
CLOUDFLARE = "CLOUDFLARE"
|
||||
DATABRICKS = "DATABRICKS"
|
||||
COMETAPI = "COMETAPI"
|
||||
HUGGINGFACE = "HUGGINGFACE"
|
||||
GITHUB_MODELS = "GITHUB_MODELS"
|
||||
MINIMAX = "MINIMAX"
|
||||
CUSTOM = "CUSTOM"
|
||||
class ConnectionScope(StrEnum):
|
||||
GLOBAL = "GLOBAL"
|
||||
SEARCH_SPACE = "SEARCH_SPACE"
|
||||
USER = "USER"
|
||||
|
||||
|
||||
class ImageGenProvider(StrEnum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
GOOGLE = "GOOGLE" # Google AI Studio
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK" # AWS Bedrock
|
||||
RECRAFT = "RECRAFT"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
XINFERENCE = "XINFERENCE"
|
||||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class VisionProvider(StrEnum):
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
OLLAMA = "OLLAMA"
|
||||
GROQ = "GROQ"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
MISTRAL = "MISTRAL"
|
||||
CUSTOM = "CUSTOM"
|
||||
class ModelSource(StrEnum):
|
||||
DISCOVERED = "DISCOVERED"
|
||||
MANUAL = "MANUAL"
|
||||
|
||||
|
||||
class LogLevel(StrEnum):
|
||||
|
|
@ -699,11 +635,11 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
default=False,
|
||||
server_default="false",
|
||||
)
|
||||
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
|
||||
# Auto model pin for this thread: concrete resolved global LLM
|
||||
# config id. NULL means no pin; Auto will resolve on the next turn.
|
||||
# Single-writer invariant: only app.services.auto_model_pin_service sets
|
||||
# or clears this column (plus bulk clears when a search space's
|
||||
# agent_llm_id changes). Unindexed: all reads are by primary key.
|
||||
# chat_model_id changes). Unindexed: all reads are by primary key.
|
||||
pinned_llm_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Surface metadata for first-party SurfSense and external chat threads.
|
||||
|
|
@ -1607,73 +1543,80 @@ class Report(BaseModel, TimestampMixin):
|
|||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class ImageGenerationConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Dedicated configuration table for image generation models.
|
||||
class Connection(BaseModel, TimestampMixin):
|
||||
__tablename__ = "connections"
|
||||
|
||||
Separate from NewLLMConfig because image generation models don't need
|
||||
system_instructions, citations_enabled, or use_default_system_instructions.
|
||||
They only need provider credentials and model parameters.
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generation_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
|
||||
provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
# Credentials
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True) # Azure-specific
|
||||
|
||||
# Additional litellm parameters
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# Relationships
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="image_generation_configs")
|
||||
|
||||
|
||||
class VisionLLMConfig(BaseModel, TimestampMixin):
|
||||
__tablename__ = "vision_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True)
|
||||
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
base_url = Column(String(500), nullable=True)
|
||||
api_key = Column(String, nullable=True)
|
||||
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="vision_llm_configs")
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="connections")
|
||||
user = relationship("User", back_populates="connections")
|
||||
models = relationship(
|
||||
"Model",
|
||||
back_populates="connection",
|
||||
order_by="Model.id",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
|
||||
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
|
||||
"(scope = 'USER' AND user_id IS NOT NULL)",
|
||||
name="ck_connections_scope_owner",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Model(BaseModel, TimestampMixin):
|
||||
__tablename__ = "models"
|
||||
|
||||
connection_id = Column(
|
||||
Integer,
|
||||
ForeignKey("connections.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
model_id = Column(String(255), nullable=False)
|
||||
display_name = Column(String(255), nullable=True)
|
||||
source = Column(
|
||||
SQLAlchemyEnum(ModelSource),
|
||||
nullable=False,
|
||||
default=ModelSource.DISCOVERED,
|
||||
server_default=ModelSource.DISCOVERED.value,
|
||||
)
|
||||
supports_chat = Column(Boolean, nullable=True)
|
||||
max_input_tokens = Column(Integer, nullable=True)
|
||||
supports_image_input = Column(Boolean, nullable=True)
|
||||
supports_tools = Column(Boolean, nullable=True)
|
||||
supports_image_generation = Column(Boolean, nullable=True)
|
||||
capabilities_override = Column(
|
||||
JSONB, nullable=False, default=dict, server_default="{}"
|
||||
)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
billing_tier = Column(String(50), nullable=True, index=True)
|
||||
catalog = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
|
||||
connection = relationship("Connection", back_populates="models")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
),
|
||||
Index("ix_models_model_id", "model_id"),
|
||||
)
|
||||
user = relationship("User", back_populates="vision_llm_configs")
|
||||
|
||||
|
||||
class ImageGeneration(BaseModel, TimestampMixin):
|
||||
|
|
@ -1707,10 +1650,9 @@ class ImageGeneration(BaseModel, TimestampMixin):
|
|||
style = Column(String(50), nullable=True) # Model-specific style parameter
|
||||
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
|
||||
|
||||
# Image generation config reference
|
||||
# 0 = Auto mode (router), negative IDs = global configs from YAML,
|
||||
# positive IDs = ImageGenerationConfig records in DB
|
||||
image_generation_config_id = Column(Integer, nullable=True)
|
||||
# Image generation model provenance.
|
||||
# 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records.
|
||||
image_gen_model_id = Column(Integer, nullable=True)
|
||||
|
||||
# Response data (full litellm response as JSONB) — present on success
|
||||
response_data = Column(JSONB, nullable=True)
|
||||
|
|
@ -1752,19 +1694,19 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
|
||||
shared_memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Search space-level LLM preferences (shared by all members)
|
||||
# Note: ID values:
|
||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||
# - Negative IDs: Global configs from YAML
|
||||
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
|
||||
agent_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
# Connection/model role bindings.
|
||||
# Note: ID values preserve the existing convention:
|
||||
# - 0: Auto mode
|
||||
# - Negative IDs: Global virtual models from global_llm_config.yaml
|
||||
# - Positive IDs: User/search-space models from the models table
|
||||
chat_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For agent/chat operations, defaults to Auto mode
|
||||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
vision_llm_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
image_gen_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For image generation, defaults to Auto mode when eligible
|
||||
vision_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
ai_file_sort_enabled = Column(
|
||||
|
|
@ -1836,23 +1778,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="SearchSourceConnector.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="search_space",
|
||||
order_by="NewLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGenerationConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="search_space",
|
||||
order_by="VisionLLMConfig.id",
|
||||
order_by="Connection.id",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
automations = relationship(
|
||||
|
|
@ -1955,64 +1886,6 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
|||
documents = relationship("Document", back_populates="connector")
|
||||
|
||||
|
||||
class NewLLMConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
New LLM configuration table that combines model settings with prompt configuration.
|
||||
|
||||
This table provides:
|
||||
- LLM model configuration (provider, model_name, api_key, etc.)
|
||||
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
- Citation toggle (enable/disable citation instructions)
|
||||
|
||||
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
|
||||
"""
|
||||
|
||||
__tablename__ = "new_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# === LLM Model Configuration (from original LLMConfig, excluding 'language') ===
|
||||
# Provider from the enum
|
||||
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
|
||||
# Custom provider name when provider is CUSTOM
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
# Just the model name without provider prefix
|
||||
model_name = Column(String(100), nullable=False)
|
||||
# API Key should be encrypted before storing
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
# For any other parameters that litellm supports
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# === Prompt Configuration ===
|
||||
# Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
# Users can customize this from the UI
|
||||
system_instructions = Column(
|
||||
Text,
|
||||
nullable=False,
|
||||
default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
)
|
||||
# Whether to use the default system instructions when system_instructions is empty
|
||||
use_default_system_instructions = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected
|
||||
# When disabled, an anti-citation prompt is injected instead
|
||||
citations_enabled = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# === Relationships ===
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="new_llm_configs")
|
||||
|
||||
|
||||
class Log(BaseModel, TimestampMixin):
|
||||
__tablename__ = "logs"
|
||||
|
||||
|
|
@ -2379,22 +2252,8 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
|
@ -2525,22 +2384,8 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from litellm.exceptions import (
|
|||
)
|
||||
from sqlalchemy.exc import IntegrityError as IntegrityError
|
||||
|
||||
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
|
||||
|
||||
# Tuples for use directly in except clauses.
|
||||
RETRYABLE_LLM_ERRORS = (
|
||||
RateLimitError,
|
||||
|
|
@ -97,38 +99,20 @@ def safe_exception_message(exc: Exception) -> str:
|
|||
|
||||
def llm_retryable_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RateLimitError):
|
||||
return PipelineMessages.RATE_LIMIT
|
||||
if isinstance(exc, Timeout):
|
||||
return PipelineMessages.LLM_TIMEOUT
|
||||
if isinstance(exc, ServiceUnavailableError):
|
||||
return PipelineMessages.LLM_UNAVAILABLE
|
||||
if isinstance(exc, BadGatewayError):
|
||||
return PipelineMessages.LLM_BAD_GATEWAY
|
||||
if isinstance(exc, InternalServerError):
|
||||
return PipelineMessages.LLM_SERVER_ERROR
|
||||
if isinstance(exc, APIConnectionError):
|
||||
return PipelineMessages.LLM_CONNECTION
|
||||
return safe_exception_message(exc)
|
||||
adapted = adapt_llm_exception(exc)
|
||||
if adapted.category is LLMErrorCategory.UNKNOWN:
|
||||
return safe_exception_message(exc)
|
||||
return adapted.user_message
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def llm_permanent_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, AuthenticationError):
|
||||
return PipelineMessages.LLM_AUTH
|
||||
if isinstance(exc, PermissionDeniedError):
|
||||
return PipelineMessages.LLM_PERMISSION
|
||||
if isinstance(exc, NotFoundError):
|
||||
return PipelineMessages.LLM_NOT_FOUND
|
||||
if isinstance(exc, BadRequestError):
|
||||
return PipelineMessages.LLM_BAD_REQUEST
|
||||
if isinstance(exc, UnprocessableEntityError):
|
||||
return PipelineMessages.LLM_UNPROCESSABLE
|
||||
if isinstance(exc, APIResponseValidationError):
|
||||
return PipelineMessages.LLM_RESPONSE
|
||||
return safe_exception_message(exc)
|
||||
adapted = adapt_llm_exception(exc)
|
||||
if adapted.category is LLMErrorCategory.UNKNOWN:
|
||||
return safe_exception_message(exc)
|
||||
return adapted.user_message
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ def build_configurable_system_prompt(
|
|||
*,
|
||||
model_name: str | None = None,
|
||||
) -> str:
|
||||
"""Build a configurable SurfSense system prompt (NewLLMConfig path).
|
||||
"""Build a configurable SurfSense system prompt.
|
||||
|
||||
See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt`
|
||||
for full parameter docs.
|
||||
|
|
@ -104,7 +104,7 @@ def build_configurable_system_prompt(
|
|||
def get_default_system_instructions() -> str:
|
||||
"""Return the default ``<system_instruction>`` block (no tools / citations).
|
||||
|
||||
Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``.
|
||||
Useful for populating the UI when editing custom system instructions.
|
||||
The output reflects the current fragment tree, not a baked-in constant.
|
||||
"""
|
||||
resolved_today = datetime.now(UTC).date().isoformat()
|
||||
|
|
|
|||
|
|
@ -348,8 +348,7 @@ def compose_system_prompt(
|
|||
mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject
|
||||
an explicit MCP routing block.
|
||||
custom_system_instructions: Free-form instructions that override
|
||||
the default ``<system_instruction>`` block (legacy support
|
||||
for ``NewLLMConfig.system_instructions``).
|
||||
the default ``<system_instruction>`` block.
|
||||
use_default_system_instructions: When ``custom_system_instructions``
|
||||
is empty/None, fall back to defaults (legacy semantics).
|
||||
citations_enabled: Include ``citations_on.md`` (true) or
|
||||
|
|
|
|||
|
|
@ -44,9 +44,9 @@ from .logs_routes import router as logs_router
|
|||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .mcp_oauth_route import router as mcp_oauth_router
|
||||
from .memory_routes import router as memory_router
|
||||
from .model_connections_routes import router as model_connections_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .new_llm_config_routes import router as new_llm_config_router
|
||||
from .notes_routes import router as notes_router
|
||||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .obsidian_plugin_routes import router as obsidian_plugin_router
|
||||
|
|
@ -63,7 +63,6 @@ from .stripe_routes import router as stripe_router
|
|||
from .team_memory_routes import router as team_memory_router
|
||||
from .teams_add_connector_route import router as teams_add_connector_router
|
||||
from .video_presentations_routes import router as video_presentations_router
|
||||
from .vision_llm_routes import router as vision_llm_router
|
||||
from .youtube_routes import router as youtube_router
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -98,7 +97,6 @@ router.include_router(
|
|||
) # Video presentation status and streaming
|
||||
router.include_router(reports_router) # Report CRUD and multi-format export
|
||||
router.include_router(image_generation_router) # Image generation via litellm
|
||||
router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
router.include_router(google_gmail_add_connector_router)
|
||||
|
|
@ -116,7 +114,7 @@ router.include_router(jira_add_connector_router)
|
|||
router.include_router(confluence_add_connector_router)
|
||||
router.include_router(clickup_add_connector_router)
|
||||
router.include_router(dropbox_add_connector_router)
|
||||
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
|
||||
router.include_router(model_connections_router) # Connection-centric model catalog
|
||||
router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter
|
||||
router.include_router(logs_router)
|
||||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from app.etl_pipeline.file_classifier import (
|
|||
PLAINTEXT_EXTENSIONS,
|
||||
)
|
||||
from app.rate_limiter import limiter
|
||||
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -98,7 +99,6 @@ class AnonQuotaResponse(BaseModel):
|
|||
class AnonModelResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
model_name: str
|
||||
billing_tier: str = "free"
|
||||
|
|
@ -131,8 +131,7 @@ async def list_anonymous_models():
|
|||
AnonModelResponse(
|
||||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("provider", ""),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
@ -160,8 +159,7 @@ async def get_anonymous_model(slug: str):
|
|||
return AnonModelResponse(
|
||||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("provider", ""),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
@ -474,7 +472,15 @@ async def stream_anonymous_chat(
|
|||
except Exception as e:
|
||||
logger.exception("Anonymous chat stream error")
|
||||
await TokenQuotaService.anon_release(session_key, ip_key, request_id)
|
||||
yield streaming_service.format_error(f"Error during chat: {e!s}")
|
||||
_, error_code, _, _, user_message, extra = classify_stream_exception(
|
||||
e,
|
||||
flow_label="chat",
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
user_message,
|
||||
error_code=error_code,
|
||||
extra=extra,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
finally:
|
||||
await TokenQuotaService.anon_release_stream_slot(client_ip)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
"""
|
||||
Image Generation routes:
|
||||
- CRUD for ImageGenerationConfig (user-created image model configs)
|
||||
- Global image gen configs endpoint (from YAML)
|
||||
- Image generation execution (calls litellm.aimage_generation())
|
||||
- CRUD for ImageGeneration records (results)
|
||||
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
|
||||
|
|
@ -16,11 +14,12 @@ from litellm import aimage_generation
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Model,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
|
|
@ -28,14 +27,14 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.billable_calls import (
|
||||
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||
QuotaInsufficientError,
|
||||
|
|
@ -43,10 +42,10 @@ from app.services.billable_calls import (
|
|||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -54,52 +53,16 @@ from app.utils.signed_image_urls import verify_image_token
|
|||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping for building litellm model strings.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
def _get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image generation configuration by ID (negative IDs)."""
|
||||
if config_id == IMAGE_GEN_AUTO_MODE_ID:
|
||||
return {
|
||||
"id": IMAGE_GEN_AUTO_MODE_ID,
|
||||
"name": "Auto (Fastest)",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
if config_id > 0:
|
||||
return None
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
"""Resolve the LiteLLM provider prefix used in model strings."""
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
def _get_global_connection(connection_id: int) -> dict | None:
|
||||
return next(
|
||||
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
|
||||
|
|
@ -115,34 +78,41 @@ async def _resolve_billing_for_image_gen(
|
|||
config that will actually run, and so we don't open an
|
||||
``ImageGeneration`` row for a request that's about to 402.
|
||||
|
||||
User-owned (positive ID) BYOK configs are always free — they cost
|
||||
the user nothing on our side. Auto mode currently treats as free
|
||||
because the underlying router can dispatch to either premium or
|
||||
free YAML configs and we don't surface the resolved deployment up
|
||||
here yet. Bringing Auto under premium billing would require
|
||||
threading the chosen deployment back from ``ImageGenRouterService``.
|
||||
User-owned (positive ID) BYOK models are always free — they cost
|
||||
the user nothing on our side. Auto mode resolves to one concrete
|
||||
global or BYOK model before billing is calculated.
|
||||
"""
|
||||
resolved_id = config_id
|
||||
if resolved_id is None:
|
||||
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
|
||||
if is_image_gen_auto_mode(resolved_id):
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space.id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
if not candidates:
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
selected = choose_auto_model_candidate(candidates, search_space.id)
|
||||
resolved_id = int(selected["id"])
|
||||
|
||||
if resolved_id < 0:
|
||||
cfg = _get_global_image_gen_config(resolved_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
base_model = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg.get("model_name", ""),
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
global_model = _get_global_model(resolved_id) or {}
|
||||
global_connection = _get_global_connection(global_model.get("connection_id", 0))
|
||||
billing_tier = str(global_model.get("billing_tier", "free")).lower()
|
||||
if global_connection and global_model.get("model_id"):
|
||||
base_model, _ = to_litellm(global_connection, global_model["model_id"])
|
||||
else:
|
||||
base_model = "global_image_model"
|
||||
catalog = global_model.get("catalog") or {}
|
||||
reserve_micros = int(
|
||||
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
)
|
||||
return (billing_tier, base_model, reserve_micros)
|
||||
|
||||
# Positive ID = user-owned BYOK image-gen config — always free.
|
||||
# Positive ID = user-owned BYOK image-gen model — always free.
|
||||
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
|
||||
|
|
@ -155,14 +125,14 @@ async def _execute_image_generation(
|
|||
Call litellm.aimage_generation() with the appropriate config.
|
||||
|
||||
Resolution order:
|
||||
1. Explicit image_generation_config_id on the request
|
||||
2. Search space's image_generation_config_id preference
|
||||
1. Explicit image_gen_model_id on the request
|
||||
2. Search space's image_gen_model_id preference
|
||||
3. Falls back to Auto mode if available
|
||||
"""
|
||||
config_id = image_gen.image_generation_config_id
|
||||
config_id = image_gen.image_gen_model_id
|
||||
if config_id is None:
|
||||
config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_gen.image_generation_config_id = config_id
|
||||
config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_gen.image_gen_model_id = config_id
|
||||
|
||||
# Build kwargs
|
||||
gen_kwargs = {}
|
||||
|
|
@ -178,36 +148,30 @@ async def _execute_image_generation(
|
|||
gen_kwargs["response_format"] = image_gen.response_format
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
raise ValueError(
|
||||
"Auto mode requested but Image Generation Router not initialized. "
|
||||
"Ensure global_llm_config.yaml has global_image_generation_configs."
|
||||
)
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=image_gen.prompt, model="auto", **gen_kwargs
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space.id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
elif config_id < 0:
|
||||
# Global config from YAML
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
if not candidates:
|
||||
raise ValueError("No image-generation models are available for Auto mode")
|
||||
config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"])
|
||||
image_gen.image_gen_model_id = config_id
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not has_capability(global_model, "image_gen"):
|
||||
raise ValueError(f"Global image generation model {config_id} not found")
|
||||
global_connection = _get_global_connection(global_model["connection_id"])
|
||||
if not global_connection:
|
||||
raise ValueError(f"Global connection for image model {config_id} not found")
|
||||
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
global_connection,
|
||||
global_model["model_id"],
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
|
|
@ -217,30 +181,28 @@ async def _execute_image_generation(
|
|||
prompt=image_gen.prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = DB ImageGenerationConfig
|
||||
# Positive ID = Model + Connection
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.filter(Model.id == config_id, Model.enabled.is_(True))
|
||||
)
|
||||
db_cfg = result.scalars().first()
|
||||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
db_model = result.scalars().first()
|
||||
if not db_model or not db_model.connection or not db_model.connection.enabled:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
conn = db_model.connection
|
||||
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
if conn.user_id is not None and conn.user_id != search_space.user_id:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
if not has_capability(db_model, "image_gen"):
|
||||
raise ValueError(f"Model {config_id} is not image-generation capable")
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
db_model.connection,
|
||||
db_model.model_id,
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
|
|
@ -260,266 +222,6 @@ async def _execute_image_generation(
|
|||
image_gen.model = hidden["model"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Generation Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global-image-generation-configs",
|
||||
response_model=list[GlobalImageGenConfigRead],
|
||||
)
|
||||
async def get_global_image_gen_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get all global image generation configs. API keys are hidden."""
|
||||
try:
|
||||
global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes across available image generation providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global image generation configs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
|
||||
async def create_image_gen_config(
|
||||
config_data: ImageGenerationConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new image generation config for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generation configs in this search space",
|
||||
)
|
||||
|
||||
db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
|
||||
async def list_image_gen_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List image generation configs for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig)
|
||||
.filter(ImageGenerationConfig.search_space_id == search_space_id)
|
||||
.order_by(ImageGenerationConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list ImageGenerationConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def get_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific image generation config by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def update_image_gen_config(
|
||||
config_id: int,
|
||||
update_data: ImageGenerationConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update an existing image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to update image generation configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
|
||||
async def delete_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an image generation config."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generation configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
"message": "Image generation config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete ImageGenerationConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Execution + Results CRUD
|
||||
# =============================================================================
|
||||
|
|
@ -568,7 +270,7 @@ async def create_image_generation(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||
session, data.image_generation_config_id, search_space
|
||||
session, data.image_gen_model_id, search_space
|
||||
)
|
||||
|
||||
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||
|
|
@ -594,7 +296,7 @@ async def create_image_generation(
|
|||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
image_gen_model_id=data.image_gen_model_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
|
|
|
|||
805
surfsense_backend/app/routes/model_connections_routes.py
Normal file
805
surfsense_backend/app/routes/model_connections_routes.py
Normal file
|
|
@ -0,0 +1,805 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Connection,
|
||||
ConnectionScope,
|
||||
Model,
|
||||
ModelSource,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
ConnectionCreate,
|
||||
ConnectionRead,
|
||||
ConnectionUpdate,
|
||||
ModelCreate,
|
||||
ModelPreviewRead,
|
||||
ModelProviderRead,
|
||||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelsBulkUpdate,
|
||||
ModelSelection,
|
||||
ModelTestPreview,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_connection_service import (
|
||||
ModelDiscoveryError,
|
||||
derive_capabilities,
|
||||
discover_models,
|
||||
test_model,
|
||||
verify_connection,
|
||||
)
|
||||
from app.services.provider_registry import REGISTRY
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _model_read(model: Model | dict) -> ModelRead:
|
||||
return ModelRead.model_validate(model)
|
||||
|
||||
|
||||
def _preview_model_read(item: dict) -> ModelPreviewRead:
|
||||
return ModelPreviewRead(
|
||||
model_id=item["model_id"],
|
||||
display_name=item.get("display_name"),
|
||||
source=item.get("source", ModelSource.DISCOVERED),
|
||||
supports_chat=item.get("supports_chat"),
|
||||
max_input_tokens=item.get("max_input_tokens"),
|
||||
supports_image_input=item.get("supports_image_input"),
|
||||
supports_tools=item.get("supports_tools"),
|
||||
supports_image_generation=item.get("supports_image_generation"),
|
||||
enabled=item.get("enabled", False),
|
||||
metadata=item.get("metadata") or item.get("catalog") or {},
|
||||
)
|
||||
|
||||
|
||||
def _connection_read(
|
||||
conn: Connection | dict, models: list[Model | dict] | None = None
|
||||
) -> ConnectionRead:
|
||||
if isinstance(conn, dict):
|
||||
payload = {
|
||||
**conn,
|
||||
"has_api_key": bool(conn.get("api_key")),
|
||||
"api_key": None,
|
||||
"models": [_model_read(model) for model in (models or [])],
|
||||
}
|
||||
payload.pop("api_key", None)
|
||||
return ConnectionRead.model_validate(payload)
|
||||
|
||||
return ConnectionRead(
|
||||
id=conn.id,
|
||||
provider=conn.provider,
|
||||
base_url=conn.base_url,
|
||||
api_key=conn.api_key,
|
||||
extra=conn.extra or {},
|
||||
scope=conn.scope,
|
||||
search_space_id=conn.search_space_id,
|
||||
user_id=conn.user_id,
|
||||
enabled=conn.enabled,
|
||||
has_api_key=bool(conn.api_key),
|
||||
models=[_model_read(model) for model in (models or [])],
|
||||
created_at=conn.created_at,
|
||||
)
|
||||
|
||||
|
||||
def _apply_model_facts(model: Model, facts: dict) -> None:
|
||||
model.supports_chat = facts.get("supports_chat")
|
||||
model.max_input_tokens = facts.get("max_input_tokens")
|
||||
model.supports_image_input = facts.get("supports_image_input")
|
||||
model.supports_tools = facts.get("supports_tools")
|
||||
model.supports_image_generation = facts.get("supports_image_generation")
|
||||
|
||||
|
||||
def _complete_selection_facts(conn: Connection, selection: ModelSelection) -> dict:
|
||||
facts = selection.model_dump()
|
||||
derived = derive_capabilities(conn, selection.model_id.strip(), selection.metadata)
|
||||
for key, value in derived.items():
|
||||
if facts.get(key) is None:
|
||||
facts[key] = value
|
||||
return facts
|
||||
|
||||
|
||||
def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model:
|
||||
source = (
|
||||
selection.source
|
||||
if isinstance(selection.source, ModelSource)
|
||||
else ModelSource(selection.source)
|
||||
)
|
||||
model = Model(
|
||||
connection_id=conn.id,
|
||||
model_id=selection.model_id.strip(),
|
||||
display_name=selection.display_name,
|
||||
source=source,
|
||||
capabilities_override={},
|
||||
enabled=selection.enabled,
|
||||
catalog=selection.metadata,
|
||||
)
|
||||
_apply_model_facts(model, _complete_selection_facts(conn, selection))
|
||||
return model
|
||||
|
||||
|
||||
def _default_model_for(models: list[Model], capability: str) -> int | None:
|
||||
for model in models:
|
||||
if model.enabled and has_capability(model, capability):
|
||||
return model.id
|
||||
return None
|
||||
|
||||
|
||||
async def _load_role_model(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
model_id: int,
|
||||
) -> Model | dict | None:
|
||||
if model_id < 0:
|
||||
return next(
|
||||
(model for model in config.GLOBAL_MODELS if model.get("id") == model_id),
|
||||
None,
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == model_id)
|
||||
)
|
||||
model = result.scalars().first()
|
||||
if model is None or model.connection.search_space_id != search_space_id:
|
||||
return None
|
||||
return model
|
||||
|
||||
|
||||
def _role_model_enabled(model: Model | dict) -> bool:
|
||||
if isinstance(model, dict):
|
||||
return bool(model.get("enabled", True))
|
||||
return bool(model.enabled and model.connection.enabled)
|
||||
|
||||
|
||||
async def _validate_role_model_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
model_id: int | None,
|
||||
capability: str,
|
||||
) -> int:
|
||||
if model_id is None or model_id == 0:
|
||||
return 0
|
||||
|
||||
model = await _load_role_model(session, search_space_id, model_id)
|
||||
if model and _role_model_enabled(model) and has_capability(model, capability):
|
||||
return model_id
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Selected model is not available for {capability}",
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_role_model_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
model_id: int | None,
|
||||
capability: str,
|
||||
) -> int:
|
||||
try:
|
||||
return await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=model_id,
|
||||
capability=capability,
|
||||
)
|
||||
except HTTPException:
|
||||
return 0
|
||||
|
||||
|
||||
async def _clear_invalid_roles(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> SearchSpace:
|
||||
search_space = await _get_search_space(session, search_space_id)
|
||||
search_space.chat_model_id = await _resolve_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=search_space.chat_model_id,
|
||||
capability="chat",
|
||||
)
|
||||
search_space.vision_model_id = await _resolve_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=search_space.vision_model_id,
|
||||
capability="vision",
|
||||
)
|
||||
search_space.image_gen_model_id = await _resolve_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=search_space.image_gen_model_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
return search_space
|
||||
|
||||
|
||||
async def _default_unset_roles(
|
||||
session: AsyncSession,
|
||||
conn: Connection,
|
||||
models: list[Model],
|
||||
) -> None:
|
||||
if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None:
|
||||
return
|
||||
search_space = await _get_search_space(session, conn.search_space_id)
|
||||
if search_space.chat_model_id is None:
|
||||
search_space.chat_model_id = _default_model_for(models, "chat")
|
||||
if search_space.vision_model_id is None:
|
||||
vision_default = None
|
||||
if search_space.chat_model_id:
|
||||
chat_model = next(
|
||||
(m for m in models if m.id == search_space.chat_model_id), None
|
||||
)
|
||||
if chat_model and has_capability(chat_model, "vision"):
|
||||
vision_default = chat_model.id
|
||||
search_space.vision_model_id = vision_default or _default_model_for(
|
||||
models, "vision"
|
||||
)
|
||||
if search_space.image_gen_model_id is None:
|
||||
search_space.image_gen_model_id = _default_model_for(models, "image_gen")
|
||||
|
||||
|
||||
@router.get("/model-providers", response_model=list[ModelProviderRead])
|
||||
async def list_model_providers(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
local_only = {"ollama_chat", "lm_studio"}
|
||||
return [
|
||||
ModelProviderRead(
|
||||
provider=provider,
|
||||
transport=spec.transport.value,
|
||||
discovery=spec.discovery,
|
||||
default_base_url=spec.default_base_url,
|
||||
base_url_required=spec.base_url_required,
|
||||
auth_style=spec.auth_style,
|
||||
local_only=provider in local_only,
|
||||
)
|
||||
for provider, spec in sorted(REGISTRY.items())
|
||||
]
|
||||
|
||||
|
||||
async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
return search_space
|
||||
|
||||
|
||||
async def _load_connection(session: AsyncSession, connection_id: int) -> Connection:
|
||||
result = await session.execute(
|
||||
select(Connection)
|
||||
.options(selectinload(Connection.models))
|
||||
.where(Connection.id == connection_id)
|
||||
)
|
||||
conn = result.scalars().first()
|
||||
if not conn:
|
||||
raise HTTPException(status_code=404, detail="Connection not found")
|
||||
return conn
|
||||
|
||||
|
||||
async def _assert_connection_access(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
conn: Connection,
|
||||
permission: str = Permission.LLM_CONFIGS_CREATE.value,
|
||||
) -> None:
|
||||
if conn.search_space_id:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
conn.search_space_id,
|
||||
permission,
|
||||
"You don't have permission to manage model connections in this search space",
|
||||
)
|
||||
return
|
||||
if conn.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Connection does not belong to user"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/global-model-connections", response_model=list[ConnectionRead])
|
||||
async def list_global_connections(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
models_by_connection: dict[int, list[dict]] = {}
|
||||
for model in config.GLOBAL_MODELS:
|
||||
models_by_connection.setdefault(model["connection_id"], []).append(model)
|
||||
return [
|
||||
_connection_read(conn, models_by_connection.get(conn["id"], []))
|
||||
for conn in config.GLOBAL_CONNECTIONS
|
||||
]
|
||||
|
||||
|
||||
@router.get("/model-connections", response_model=list[ConnectionRead])
|
||||
async def list_connections(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
stmt = select(Connection).options(selectinload(Connection.models))
|
||||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to view model connections in this search space",
|
||||
)
|
||||
stmt = stmt.where(Connection.search_space_id == search_space_id)
|
||||
else:
|
||||
stmt = stmt.where(Connection.user_id == user.id)
|
||||
result = await session.execute(stmt.order_by(Connection.id))
|
||||
return [
|
||||
_connection_read(conn, list(conn.models)) for conn in result.scalars().all()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/model-connections", response_model=ConnectionRead)
|
||||
async def create_connection(
|
||||
data: ConnectionCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if data.scope == ConnectionScope.GLOBAL:
|
||||
raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only")
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE:
|
||||
if data.search_space_id is None:
|
||||
raise HTTPException(status_code=400, detail="search_space_id is required")
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
payload = data.model_dump(exclude={"search_space_id", "models"})
|
||||
|
||||
conn = Connection(
|
||||
**payload,
|
||||
search_space_id=data.search_space_id
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE
|
||||
else None,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
seen_model_ids: set[str] = set()
|
||||
for selection in data.models:
|
||||
model_id = selection.model_id.strip()
|
||||
if not model_id or model_id in seen_model_ids:
|
||||
continue
|
||||
seen_model_ids.add(model_id)
|
||||
session.add(_selection_to_model(conn, selection))
|
||||
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, conn.id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, conn.id)
|
||||
return _connection_read(conn, list(conn.models))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/model-connections/discover-preview", response_model=list[ModelPreviewRead]
|
||||
)
|
||||
async def preview_connection_models(
|
||||
data: ConnectionCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
|
||||
draft = Connection(
|
||||
provider=data.provider,
|
||||
base_url=data.base_url,
|
||||
api_key=data.api_key,
|
||||
extra=data.extra or {},
|
||||
scope=data.scope,
|
||||
enabled=data.enabled,
|
||||
search_space_id=data.search_space_id
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE
|
||||
else None,
|
||||
user_id=user.id,
|
||||
)
|
||||
try:
|
||||
discovered = await discover_models(draft)
|
||||
except ModelDiscoveryError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
return [_preview_model_read(item) for item in discovered]
|
||||
|
||||
|
||||
@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse)
|
||||
async def test_preview_connection_model(
|
||||
data: ModelTestPreview,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
|
||||
model_id = data.model_id.strip()
|
||||
if not model_id:
|
||||
raise HTTPException(status_code=400, detail="model_id is required")
|
||||
|
||||
draft = Connection(
|
||||
provider=data.provider,
|
||||
base_url=data.base_url,
|
||||
api_key=data.api_key,
|
||||
extra=data.extra or {},
|
||||
scope=data.scope,
|
||||
enabled=data.enabled,
|
||||
search_space_id=data.search_space_id
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE
|
||||
else None,
|
||||
user_id=user.id,
|
||||
)
|
||||
model = Model(
|
||||
connection_id=0,
|
||||
model_id=model_id,
|
||||
source=ModelSource.MANUAL,
|
||||
enabled=True,
|
||||
capabilities_override={},
|
||||
catalog={},
|
||||
)
|
||||
result = await test_model(draft, model)
|
||||
return VerifyConnectionResponse(
|
||||
status=result.status, ok=result.ok, message=result.message
|
||||
)
|
||||
|
||||
|
||||
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
|
||||
async def update_connection(
|
||||
connection_id: int,
|
||||
data: ConnectionUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(conn, key, value)
|
||||
await session.commit()
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
return _connection_read(conn, list(conn.models))
|
||||
|
||||
|
||||
@router.delete("/model-connections/{connection_id}")
|
||||
async def delete_connection(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_DELETE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
await session.delete(conn)
|
||||
await session.commit()
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse
|
||||
)
|
||||
async def verify_model_connection(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
result = await verify_connection(conn)
|
||||
return VerifyConnectionResponse(
|
||||
status=result.status, ok=result.ok, message=result.message
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/model-connections/{connection_id}/discover", response_model=list[ModelRead]
|
||||
)
|
||||
async def discover_connection_models(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
try:
|
||||
discovered = await discover_models(conn)
|
||||
except ModelDiscoveryError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
by_model_id = {model.model_id: model for model in conn.models}
|
||||
for item in discovered:
|
||||
db_model = by_model_id.get(item["model_id"])
|
||||
if db_model is None:
|
||||
db_model = Model(
|
||||
connection_id=conn.id,
|
||||
model_id=item["model_id"],
|
||||
display_name=item.get("display_name"),
|
||||
source=item["source"],
|
||||
capabilities_override={},
|
||||
enabled=False,
|
||||
catalog=item.get("metadata") or {},
|
||||
)
|
||||
_apply_model_facts(db_model, item)
|
||||
session.add(db_model)
|
||||
else:
|
||||
db_model.display_name = item.get("display_name") or db_model.display_name
|
||||
_apply_model_facts(db_model, item)
|
||||
db_model.catalog = item.get("metadata") or db_model.catalog
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
if conn.search_space_id is not None:
|
||||
await _clear_invalid_roles(session, conn.search_space_id)
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
return [_model_read(model) for model in conn.models]
|
||||
|
||||
|
||||
@router.post("/model-connections/{connection_id}/models", response_model=ModelRead)
|
||||
async def add_manual_model(
|
||||
connection_id: int,
|
||||
data: ModelCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
|
||||
model_id = data.model_id.strip()
|
||||
if not model_id:
|
||||
raise HTTPException(status_code=400, detail="model_id is required")
|
||||
if any(existing.model_id == model_id for existing in conn.models):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Model already exists on this connection"
|
||||
)
|
||||
|
||||
capabilities = derive_capabilities(conn, model_id)
|
||||
model = Model(
|
||||
connection_id=conn.id,
|
||||
model_id=model_id,
|
||||
display_name=data.display_name or None,
|
||||
source=ModelSource.MANUAL,
|
||||
capabilities_override={},
|
||||
enabled=True,
|
||||
catalog={},
|
||||
)
|
||||
_apply_model_facts(model, capabilities)
|
||||
session.add(model)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
if conn.search_space_id is not None:
|
||||
await _clear_invalid_roles(session, conn.search_space_id)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
return _model_read(model)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/model-connections/{connection_id}/models", response_model=list[ModelRead]
|
||||
)
|
||||
async def bulk_update_models(
|
||||
connection_id: int,
|
||||
data: ModelsBulkUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
|
||||
model_ids = set(data.model_ids)
|
||||
await session.execute(
|
||||
update(Model)
|
||||
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
|
||||
.values(enabled=data.enabled)
|
||||
)
|
||||
await session.commit()
|
||||
session.expire_all()
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
session.expire_all()
|
||||
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
|
||||
.order_by(Model.id)
|
||||
)
|
||||
return [_model_read(model) for model in result.scalars().all()]
|
||||
|
||||
|
||||
@router.put("/models/{model_id}", response_model=ModelRead)
|
||||
async def update_model(
|
||||
model_id: int,
|
||||
data: ModelUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == model_id)
|
||||
)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = model.connection.search_space_id
|
||||
update = data.model_dump(exclude_unset=True)
|
||||
for key, value in update.items():
|
||||
setattr(model, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
return _model_read(model)
|
||||
|
||||
|
||||
@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse)
|
||||
async def test_connection_model(
|
||||
model_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == model_id)
|
||||
)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
result = await test_model(model.connection, model)
|
||||
await session.commit()
|
||||
return VerifyConnectionResponse(
|
||||
status=result.status, ok=result.ok, message=result.message
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead
|
||||
)
|
||||
async def get_model_roles(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to view model roles in this search space",
|
||||
)
|
||||
search_space = await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
return ModelRolesRead(
|
||||
chat_model_id=search_space.chat_model_id,
|
||||
vision_model_id=search_space.vision_model_id,
|
||||
image_gen_model_id=search_space.image_gen_model_id,
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead
|
||||
)
|
||||
async def update_model_roles(
|
||||
search_space_id: int,
|
||||
data: ModelRolesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update model roles in this search space",
|
||||
)
|
||||
search_space = await _get_search_space(session, search_space_id)
|
||||
updates = data.model_dump(exclude_unset=True)
|
||||
if "chat_model_id" in updates:
|
||||
previous_chat_model_id = search_space.chat_model_id
|
||||
next_chat_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=updates["chat_model_id"],
|
||||
capability="chat",
|
||||
)
|
||||
search_space.chat_model_id = next_chat_model_id
|
||||
if next_chat_model_id != previous_chat_model_id:
|
||||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(pinned_llm_config_id=None)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)",
|
||||
search_space_id,
|
||||
previous_chat_model_id,
|
||||
next_chat_model_id,
|
||||
)
|
||||
if "vision_model_id" in updates:
|
||||
search_space.vision_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=updates["vision_model_id"],
|
||||
capability="vision",
|
||||
)
|
||||
if "image_gen_model_id" in updates:
|
||||
search_space.image_gen_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=updates["image_gen_model_id"],
|
||||
capability="image_gen",
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
return ModelRolesRead(
|
||||
chat_model_id=search_space.chat_model_id,
|
||||
vision_model_id=search_space.vision_model_id,
|
||||
image_gen_model_id=search_space.image_gen_model_id,
|
||||
)
|
||||
|
|
@ -1741,12 +1741,11 @@ async def handle_new_chat(
|
|||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
# Use agent_llm_id from search space for chat operations
|
||||
# Positive IDs load from NewLLMConfig database table
|
||||
# Negative IDs load from YAML global configs
|
||||
# Falls back to -1 (first global config) if not configured
|
||||
# Use the converged model-connections role for chat operations.
|
||||
# Positive IDs load Model + Connection rows; negative IDs load
|
||||
# virtual GLOBAL models; 0 means Auto.
|
||||
llm_config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
search_space.chat_model_id if search_space.chat_model_id is not None else 0
|
||||
)
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
|
|
@ -2228,7 +2227,7 @@ async def regenerate_response(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
llm_config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
search_space.chat_model_id if search_space.chat_model_id is not None else 0
|
||||
)
|
||||
|
||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
|
|
@ -2393,7 +2392,7 @@ async def resume_chat(
|
|||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
llm_config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
search_space.chat_model_id if search_space.chat_model_id is not None else 0
|
||||
)
|
||||
|
||||
decisions = [d.model_dump() for d in request.decisions]
|
||||
|
|
|
|||
|
|
@ -1,480 +0,0 @@
|
|||
"""
|
||||
API routes for NewLLMConfig CRUD operations.
|
||||
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.prompts.default_system_instructions import get_default_system_instructions
|
||||
from app.schemas import (
|
||||
DefaultSystemInstructionsResponse,
|
||||
GlobalNewLLMConfigRead,
|
||||
NewLLMConfigCreate,
|
||||
NewLLMConfigRead,
|
||||
NewLLMConfigUpdate,
|
||||
)
|
||||
from app.services.llm_service import validate_llm_config
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
|
||||
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
|
||||
|
||||
There is no DB column for ``supports_image_input`` — the value is
|
||||
resolved at the API boundary from LiteLLM's authoritative model map
|
||||
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
|
||||
the response shape consistent across list / detail / create / update
|
||||
endpoints without having to remember to set the field at every call
|
||||
site.
|
||||
"""
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
)
|
||||
# ``model_validate`` runs the Pydantic conversion using the ORM
|
||||
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
|
||||
# then we layer the derived field on. ``model_copy(update=...)`` keeps
|
||||
# the surface immutable from the caller's perspective.
|
||||
base_read = NewLLMConfigRead.model_validate(config)
|
||||
return base_read.model_copy(update={"supports_image_input": supports_image_input})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Configs Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead])
|
||||
async def get_global_new_llm_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get all available global NewLLMConfig configurations.
|
||||
These are pre-configured by the system administrator and available to all users.
|
||||
API keys are not exposed through this endpoint.
|
||||
|
||||
Includes:
|
||||
- Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing
|
||||
- Global configs (negative IDs): Individual pre-configured LLM providers
|
||||
"""
|
||||
try:
|
||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
# Only include Auto mode if there are actual global configs to route to
|
||||
# Auto mode requires at least one global config with valid API key
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"litellm_params": {},
|
||||
"system_instructions": "",
|
||||
"use_default_system_instructions": True,
|
||||
"citations_enabled": True,
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
"anonymous_enabled": False,
|
||||
"seo_enabled": False,
|
||||
"seo_slug": None,
|
||||
"seo_title": None,
|
||||
"seo_description": None,
|
||||
"quota_reserve_tokens": None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# includes at least one vision-capable deployment, so
|
||||
# treat Auto as image-capable. The router itself will
|
||||
# still pick a vision-capable deployment for messages
|
||||
# carrying image_url blocks (LiteLLM Router falls back
|
||||
# on ``404`` per its ``allowed_fails`` policy).
|
||||
"supports_image_input": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add individual global configs
|
||||
for cfg in global_configs:
|
||||
# Capability resolution: explicit value (YAML override or OR
|
||||
# `_supports_image_input(model)` payload baked in by the
|
||||
# OpenRouter integration service) wins. Fall back to the
|
||||
# LiteLLM-driven helper which default-allows on unknown so
|
||||
# we don't hide vision-capable models that happen to lack a
|
||||
# YAML annotation. The streaming task safety net is the
|
||||
# only place a False ever blocks.
|
||||
if "supports_image_input" in cfg:
|
||||
supports_image_input = bool(cfg.get("supports_image_input"))
|
||||
else:
|
||||
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||
cfg_base_model = (
|
||||
cfg_litellm_params.get("base_model")
|
||||
if isinstance(cfg_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=cfg_base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
safe_config = {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
# New prompt configuration fields
|
||||
"system_instructions": cfg.get("system_instructions", ""),
|
||||
"use_default_system_instructions": cfg.get(
|
||||
"use_default_system_instructions", True
|
||||
),
|
||||
"citations_enabled": cfg.get("citations_enabled", True),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"is_premium": cfg.get("billing_tier", "free") == "premium",
|
||||
"anonymous_enabled": cfg.get("anonymous_enabled", False),
|
||||
"seo_enabled": cfg.get("seo_enabled", False),
|
||||
"seo_slug": cfg.get("seo_slug"),
|
||||
"seo_title": cfg.get("seo_title"),
|
||||
"seo_description": cfg.get("seo_description"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
}
|
||||
safe_configs.append(safe_config)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global NewLLMConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch global configurations: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CRUD Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/new-llm-configs", response_model=NewLLMConfigRead)
|
||||
async def create_new_llm_config(
|
||||
config_data: NewLLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Create a new NewLLMConfig for a search space.
|
||||
Requires LLM_CONFIGS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create LLM configurations in this search space",
|
||||
)
|
||||
|
||||
# Validate the LLM configuration by making a test API call
|
||||
is_valid, error_message = await validate_llm_config(
|
||||
provider=config_data.provider.value,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
api_base=config_data.api_base,
|
||||
custom_provider=config_data.custom_provider,
|
||||
litellm_params=config_data.litellm_params,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid LLM configuration: {error_message}",
|
||||
)
|
||||
|
||||
# Create the config with user association
|
||||
db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
||||
return _serialize_byok_config(db_config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead])
|
||||
async def list_new_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get all NewLLMConfigs for a search space.
|
||||
Requires LLM_CONFIGS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig)
|
||||
.filter(NewLLMConfig.search_space_id == search_space_id)
|
||||
.order_by(NewLLMConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list NewLLMConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configurations: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/new-llm-configs/default-system-instructions",
|
||||
response_model=DefaultSystemInstructionsResponse,
|
||||
)
|
||||
async def get_default_system_instructions_endpoint(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template.
|
||||
Useful for pre-populating the UI when creating a new configuration.
|
||||
"""
|
||||
return DefaultSystemInstructionsResponse(
|
||||
default_system_instructions=get_default_system_instructions()
|
||||
)
|
||||
|
||||
|
||||
@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
|
||||
async def get_new_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a specific NewLLMConfig by ID.
|
||||
Requires LLM_CONFIGS_READ permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
|
||||
async def update_new_llm_config(
|
||||
config_id: int,
|
||||
update_data: NewLLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update an existing NewLLMConfig.
|
||||
Requires LLM_CONFIGS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update LLM configurations in this search space",
|
||||
)
|
||||
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
|
||||
# If updating LLM settings, validate them
|
||||
if any(
|
||||
key in update_dict
|
||||
for key in [
|
||||
"provider",
|
||||
"model_name",
|
||||
"api_key",
|
||||
"api_base",
|
||||
"custom_provider",
|
||||
"litellm_params",
|
||||
]
|
||||
):
|
||||
# Build the validation config from existing + updates
|
||||
validation_config = {
|
||||
"provider": update_dict.get("provider", config.provider).value
|
||||
if hasattr(update_dict.get("provider", config.provider), "value")
|
||||
else update_dict.get("provider", config.provider.value),
|
||||
"model_name": update_dict.get("model_name", config.model_name),
|
||||
"api_key": update_dict.get("api_key", config.api_key),
|
||||
"api_base": update_dict.get("api_base", config.api_base),
|
||||
"custom_provider": update_dict.get(
|
||||
"custom_provider", config.custom_provider
|
||||
),
|
||||
"litellm_params": update_dict.get(
|
||||
"litellm_params", config.litellm_params
|
||||
),
|
||||
}
|
||||
|
||||
is_valid, error_message = await validate_llm_config(
|
||||
provider=validation_config["provider"],
|
||||
model_name=validation_config["model_name"],
|
||||
api_key=validation_config["api_key"],
|
||||
api_base=validation_config["api_base"],
|
||||
custom_provider=validation_config["custom_provider"],
|
||||
litellm_params=validation_config["litellm_params"],
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid LLM configuration: {error_message}",
|
||||
)
|
||||
|
||||
# Apply updates
|
||||
for key, value in update_dict.items():
|
||||
setattr(config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/new-llm-configs/{config_id}", response_model=dict)
|
||||
async def delete_new_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete a NewLLMConfig.
|
||||
Requires LLM_CONFIGS_DELETE permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_DELETE.value,
|
||||
"You don't have permission to delete LLM configurations in this search space",
|
||||
)
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
|
||||
return {"message": "Configuration deleted successfully", "id": config_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete NewLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete configuration: {e!s}"
|
||||
) from e
|
||||
|
|
@ -1,27 +1,20 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
NewChatThread,
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
SearchSpaceRole,
|
||||
User,
|
||||
VisionLLMConfig,
|
||||
get_async_session,
|
||||
get_default_roles_config,
|
||||
)
|
||||
from app.schemas import (
|
||||
LLMPreferencesRead,
|
||||
LLMPreferencesUpdate,
|
||||
SearchSpaceCreate,
|
||||
SearchSpaceRead,
|
||||
SearchSpaceUpdate,
|
||||
|
|
@ -377,357 +370,6 @@ async def delete_search_space(
|
|||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM Preferences Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def _get_llm_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get an LLM config by ID as a dictionary. Returns database config for positive IDs,
|
||||
global config for negative IDs, Auto mode config for ID 0, or None if ID is None.
|
||||
"""
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
# Auto mode (ID 0) - uses LiteLLM Router for load balancing
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"litellm_params": {},
|
||||
"system_instructions": "",
|
||||
"use_default_system_instructions": True,
|
||||
"citations_enabled": True,
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
# Global config - find from YAML
|
||||
global_configs = config.GLOBAL_LLM_CONFIGS
|
||||
for cfg in global_configs:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base"),
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"system_instructions": cfg.get("system_instructions", ""),
|
||||
"use_default_system_instructions": cfg.get(
|
||||
"use_default_system_instructions", True
|
||||
),
|
||||
"citations_enabled": cfg.get("citations_enabled", True),
|
||||
"is_global": True,
|
||||
}
|
||||
return None
|
||||
else:
|
||||
# Database config - convert to dict
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_key": db_config.api_key,
|
||||
"api_base": db_config.api_base,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"system_instructions": db_config.system_instructions or "",
|
||||
"use_default_system_instructions": db_config.use_default_system_instructions,
|
||||
"citations_enabled": db_config.citations_enabled,
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_image_gen_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get an image generation config by ID as a dictionary.
|
||||
Returns Auto mode for ID 0, global config for negative IDs,
|
||||
DB ImageGenerationConfig for positive IDs, or None.
|
||||
"""
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available image generation providers",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
# Positive ID: query ImageGenerationConfig table
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_base": db_config.api_base,
|
||||
"api_version": db_config.api_version,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_vision_llm_config_by_id(
|
||||
session: AsyncSession, config_id: int | None
|
||||
) -> dict | None:
|
||||
if config_id is None:
|
||||
return None
|
||||
|
||||
if config_id == 0:
|
||||
return {
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available vision LLM providers",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if db_config:
|
||||
return {
|
||||
"id": db_config.id,
|
||||
"name": db_config.name,
|
||||
"description": db_config.description,
|
||||
"provider": db_config.provider.value if db_config.provider else None,
|
||||
"custom_provider": db_config.custom_provider,
|
||||
"model_name": db_config.model_name,
|
||||
"api_base": db_config.api_base,
|
||||
"api_version": db_config.api_version,
|
||||
"litellm_params": db_config.litellm_params or {},
|
||||
"created_at": db_config.created_at.isoformat()
|
||||
if db_config.created_at
|
||||
else None,
|
||||
"search_space_id": db_config.search_space_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
)
|
||||
async def get_llm_preferences(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get LLM preferences (role assignments) for a search space.
|
||||
Requires LLM_CONFIGS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_READ.value,
|
||||
"You don't have permission to view LLM preferences",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
# Get full config objects for each role
|
||||
agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm_config = await _get_vision_llm_config_by_id(
|
||||
session, search_space.vision_llm_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
agent_llm=agent_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm_config=vision_llm_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get LLM preferences")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
)
|
||||
async def update_llm_preferences(
|
||||
search_space_id: int,
|
||||
preferences: LLMPreferencesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update LLM preferences (role assignments) for a search space.
|
||||
Requires LLM_CONFIGS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update LLM preferences",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
# Update preferences
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
previous_agent_llm_id = search_space.agent_llm_id
|
||||
for key, value in update_data.items():
|
||||
setattr(search_space, key, value)
|
||||
|
||||
agent_llm_changed = (
|
||||
"agent_llm_id" in update_data
|
||||
and update_data["agent_llm_id"] != previous_agent_llm_id
|
||||
)
|
||||
if agent_llm_changed:
|
||||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(pinned_llm_config_id=None)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||
search_space_id,
|
||||
previous_agent_llm_id,
|
||||
update_data["agent_llm_id"],
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
|
||||
# Get full config objects for response
|
||||
agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id)
|
||||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm_config = await _get_vision_llm_config_by_id(
|
||||
session, search_space.vision_llm_config_id
|
||||
)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
agent_llm=agent_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm_config=vision_llm_config,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update LLM preferences")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/searchspaces/{search_space_id}/snapshots")
|
||||
async def list_search_space_snapshots(
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -1,304 +0,0 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Permission,
|
||||
User,
|
||||
VisionLLMConfig,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
GlobalVisionLLMConfigRead,
|
||||
VisionLLMConfigCreate,
|
||||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
from app.services.vision_model_list_service import get_vision_model_list
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision Model Catalogue (from OpenRouter, filtered for image-input models)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class VisionModelListItem(BaseModel):
|
||||
value: str
|
||||
label: str
|
||||
provider: str
|
||||
context_window: str | None = None
|
||||
|
||||
|
||||
@router.get("/vision-models", response_model=list[VisionModelListItem])
|
||||
async def list_vision_models(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return vision-capable models sourced from OpenRouter (filtered by image input)."""
|
||||
try:
|
||||
return await get_vision_model_list()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch vision model list")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch vision model list: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Vision LLM Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global-vision-llm-configs",
|
||||
response_model=list[GlobalVisionLLMConfigRead],
|
||||
)
|
||||
async def get_global_vision_llm_configs(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
global_configs = config.GLOBAL_VISION_LLM_CONFIGS
|
||||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes across available vision LLM providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global vision LLM configs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# VisionLLMConfig CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead)
|
||||
async def create_vision_llm_config(
|
||||
config_data: VisionLLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead])
|
||||
async def list_vision_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
"You don't have permission to view vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig)
|
||||
.filter(VisionLLMConfig.search_space_id == search_space_id)
|
||||
.order_by(VisionLLMConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list VisionLLMConfigs")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
|
||||
async def get_vision_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_READ.value,
|
||||
"You don't have permission to view vision LLM configs in this search space",
|
||||
)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead)
|
||||
async def update_vision_llm_config(
|
||||
config_id: int,
|
||||
update_data: VisionLLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to update vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
return db_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/vision-llm-configs/{config_id}", response_model=dict)
|
||||
async def delete_vision_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_DELETE.value,
|
||||
"You don't have permission to delete vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
"message": "Vision LLM config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete VisionLLMConfig")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
|
@ -34,16 +34,27 @@ from .folders import (
|
|||
)
|
||||
from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest
|
||||
from .image_generation import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigPublic,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .model_connections import (
|
||||
ConnectionCreate,
|
||||
ConnectionRead,
|
||||
ConnectionUpdate,
|
||||
ModelCreate,
|
||||
ModelPreviewRead,
|
||||
ModelProviderRead,
|
||||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelsBulkUpdate,
|
||||
ModelSelection,
|
||||
ModelTestPreview,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
from .new_chat import (
|
||||
ChatMessage,
|
||||
NewChatMessageAppend,
|
||||
|
|
@ -58,16 +69,6 @@ from .new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from .new_llm_config import (
|
||||
DefaultSystemInstructionsResponse,
|
||||
GlobalNewLLMConfigRead,
|
||||
LLMPreferencesRead,
|
||||
LLMPreferencesUpdate,
|
||||
NewLLMConfigCreate,
|
||||
NewLLMConfigPublic,
|
||||
NewLLMConfigRead,
|
||||
NewLLMConfigUpdate,
|
||||
)
|
||||
from .rbac_schemas import (
|
||||
InviteAcceptRequest,
|
||||
InviteAcceptResponse,
|
||||
|
|
@ -126,13 +127,6 @@ from .video_presentations import (
|
|||
VideoPresentationRead,
|
||||
VideoPresentationUpdate,
|
||||
)
|
||||
from .vision_llm import (
|
||||
GlobalVisionLLMConfigRead,
|
||||
VisionLLMConfigCreate,
|
||||
VisionLLMConfigPublic,
|
||||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Folder schemas
|
||||
|
|
@ -144,12 +138,15 @@ __all__ = [
|
|||
"ChunkCreate",
|
||||
"ChunkRead",
|
||||
"ChunkUpdate",
|
||||
# Model connection schemas
|
||||
"ConnectionCreate",
|
||||
"ConnectionRead",
|
||||
"ConnectionUpdate",
|
||||
"CreateCreditCheckoutSessionRequest",
|
||||
"CreateCreditCheckoutSessionResponse",
|
||||
"CreditPurchaseHistoryResponse",
|
||||
"CreditPurchaseRead",
|
||||
"CreditStripeStatusResponse",
|
||||
"DefaultSystemInstructionsResponse",
|
||||
# Document schemas
|
||||
"DocumentBase",
|
||||
"DocumentMove",
|
||||
|
|
@ -172,19 +169,10 @@ __all__ = [
|
|||
"FolderRead",
|
||||
"FolderReorder",
|
||||
"FolderUpdate",
|
||||
"GlobalImageGenConfigRead",
|
||||
"GlobalNewLLMConfigRead",
|
||||
# Vision LLM Config schemas
|
||||
"GlobalVisionLLMConfigRead",
|
||||
"GoogleDriveIndexRequest",
|
||||
"GoogleDriveIndexingOptions",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
# Image Generation Config schemas
|
||||
"ImageGenerationConfigCreate",
|
||||
"ImageGenerationConfigPublic",
|
||||
"ImageGenerationConfigRead",
|
||||
"ImageGenerationConfigUpdate",
|
||||
# Image Generation schemas
|
||||
"ImageGenerationCreate",
|
||||
"ImageGenerationListRead",
|
||||
|
|
@ -196,9 +184,6 @@ __all__ = [
|
|||
"InviteInfoResponse",
|
||||
"InviteRead",
|
||||
"InviteUpdate",
|
||||
# LLM Preferences schemas
|
||||
"LLMPreferencesRead",
|
||||
"LLMPreferencesUpdate",
|
||||
# Log schemas
|
||||
"LogBase",
|
||||
"LogCreate",
|
||||
|
|
@ -217,6 +202,16 @@ __all__ = [
|
|||
"MembershipRead",
|
||||
"MembershipReadWithUser",
|
||||
"MembershipUpdate",
|
||||
"ModelCreate",
|
||||
"ModelPreviewRead",
|
||||
"ModelProviderRead",
|
||||
"ModelRead",
|
||||
"ModelRolesRead",
|
||||
"ModelRolesUpdate",
|
||||
"ModelSelection",
|
||||
"ModelTestPreview",
|
||||
"ModelUpdate",
|
||||
"ModelsBulkUpdate",
|
||||
"NewChatMessageAppend",
|
||||
"NewChatMessageCreate",
|
||||
"NewChatMessageRead",
|
||||
|
|
@ -225,11 +220,6 @@ __all__ = [
|
|||
"NewChatThreadRead",
|
||||
"NewChatThreadUpdate",
|
||||
"NewChatThreadWithMessages",
|
||||
# NewLLMConfig schemas
|
||||
"NewLLMConfigCreate",
|
||||
"NewLLMConfigPublic",
|
||||
"NewLLMConfigRead",
|
||||
"NewLLMConfigUpdate",
|
||||
"PagePurchaseHistoryResponse",
|
||||
"PagePurchaseRead",
|
||||
"PaginatedResponse",
|
||||
|
|
@ -267,13 +257,10 @@ __all__ = [
|
|||
"UserRead",
|
||||
"UserSearchSpaceAccess",
|
||||
"UserUpdate",
|
||||
"VerifyConnectionResponse",
|
||||
# Video Presentation schemas
|
||||
"VideoPresentationBase",
|
||||
"VideoPresentationCreate",
|
||||
"VideoPresentationRead",
|
||||
"VideoPresentationUpdate",
|
||||
"VisionLLMConfigCreate",
|
||||
"VisionLLMConfigPublic",
|
||||
"VisionLLMConfigRead",
|
||||
"VisionLLMConfigUpdate",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,109 +1,10 @@
|
|||
"""
|
||||
Pydantic schemas for Image Generation configs and generation requests.
|
||||
"""Pydantic schemas for image generation requests/results."""
|
||||
|
||||
ImageGenerationConfig: CRUD schemas for user-created image gen model configs.
|
||||
ImageGeneration: Schemas for the actual image generation requests/results.
|
||||
GlobalImageGenConfigRead: Schema for admin-configured YAML configs.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import ImageGenProvider
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ImageGenerationConfigBase(BaseModel):
|
||||
"""Base schema with fields for ImageGenerationConfig."""
|
||||
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the config"
|
||||
)
|
||||
description: str | None = Field(
|
||||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
provider: ImageGenProvider = Field(
|
||||
...,
|
||||
description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)",
|
||||
)
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
None,
|
||||
max_length=50,
|
||||
description="Azure-specific API version (e.g., '2024-02-15-preview')",
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
|
||||
"""Schema for creating a new ImageGenerationConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfigUpdate(BaseModel):
|
||||
"""Schema for updating an existing ImageGenerationConfig. All fields optional."""
|
||||
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: ImageGenProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ImageGenerationConfigRead(ImageGenerationConfigBase):
|
||||
"""Schema for reading an ImageGenerationConfig (includes id and timestamps)."""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ImageGenerationConfigPublic(BaseModel):
|
||||
"""Public schema that hides the API key (for list views)."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: ImageGenProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGeneration (request/result) Schemas
|
||||
# =============================================================================
|
||||
|
|
@ -136,12 +37,12 @@ class ImageGenerationCreate(BaseModel):
|
|||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the generation with"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
image_gen_model_id: int | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Image generation config ID. "
|
||||
"0 = Auto mode (router), negative = global YAML config, positive = DB config. "
|
||||
"If not provided, uses the search space's image_generation_config_id preference."
|
||||
"Image generation model ID. "
|
||||
"0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. "
|
||||
"If not provided, uses the search space's image_gen_model_id preference."
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -157,7 +58,7 @@ class ImageGenerationRead(BaseModel):
|
|||
size: str | None = None
|
||||
style: str | None = None
|
||||
response_format: str | None = None
|
||||
image_generation_config_id: int | None = None
|
||||
image_gen_model_id: int | None = None
|
||||
response_data: dict[str, Any] | None = None
|
||||
error_message: str | None = None
|
||||
search_space_id: int
|
||||
|
|
@ -203,58 +104,3 @@ class ImageGenerationListRead(BaseModel):
|
|||
is_success=obj.response_data is not None,
|
||||
image_count=image_count,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Image Gen Config (from YAML)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GlobalImageGenConfigRead(BaseModel):
|
||||
"""
|
||||
Schema for reading global image generation configs from YAML.
|
||||
Global configs have negative IDs. API key is hidden.
|
||||
ID 0 is reserved for Auto mode (LiteLLM Router load balancing).
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_micros: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the reservation amount (in micro-USD) used when "
|
||||
"this image generation is premium. Falls back to "
|
||||
"QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted."
|
||||
),
|
||||
)
|
||||
|
|
|
|||
148
surfsense_backend/app/schemas/model_connections.py
Normal file
148
surfsense_backend/app/schemas/model_connections.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import ConnectionScope, ModelSource
|
||||
|
||||
|
||||
class ModelRead(BaseModel):
|
||||
id: int
|
||||
connection_id: int
|
||||
model_id: str
|
||||
display_name: str | None = None
|
||||
source: ModelSource | str
|
||||
supports_chat: bool | None = None
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
supports_tools: bool | None = None
|
||||
supports_image_generation: bool | None = None
|
||||
capabilities_override: dict[str, Any] = Field(default_factory=dict)
|
||||
enabled: bool
|
||||
billing_tier: str | None = None
|
||||
catalog: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ConnectionRead(BaseModel):
|
||||
id: int
|
||||
provider: str
|
||||
base_url: str | None = None
|
||||
api_key: str | None = None
|
||||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
scope: ConnectionScope | str
|
||||
search_space_id: int | None = None
|
||||
user_id: uuid.UUID | None = None
|
||||
enabled: bool
|
||||
has_api_key: bool
|
||||
models: list[ModelRead] = Field(default_factory=list)
|
||||
created_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ModelSelection(BaseModel):
|
||||
model_id: str = Field(..., max_length=255)
|
||||
display_name: str | None = Field(None, max_length=255)
|
||||
source: ModelSource | str = ModelSource.DISCOVERED
|
||||
supports_chat: bool | None = None
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
supports_tools: bool | None = None
|
||||
supports_image_generation: bool | None = None
|
||||
enabled: bool = False
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelPreviewRead(BaseModel):
|
||||
model_id: str
|
||||
display_name: str | None = None
|
||||
source: ModelSource | str = ModelSource.DISCOVERED
|
||||
supports_chat: bool | None = None
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
supports_tools: bool | None = None
|
||||
supports_image_generation: bool | None = None
|
||||
enabled: bool = False
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ConnectionCreate(BaseModel):
|
||||
provider: str = Field(..., max_length=100)
|
||||
base_url: str | None = Field(None, max_length=500)
|
||||
api_key: str | None = None
|
||||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
scope: ConnectionScope = ConnectionScope.SEARCH_SPACE
|
||||
search_space_id: int | None = None
|
||||
enabled: bool = True
|
||||
models: list[ModelSelection] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ModelTestPreview(ConnectionCreate):
|
||||
model_id: str = Field(..., max_length=255)
|
||||
|
||||
|
||||
class ConnectionUpdate(BaseModel):
|
||||
provider: str | None = Field(None, max_length=100)
|
||||
base_url: str | None = Field(None, max_length=500)
|
||||
api_key: str | None = None
|
||||
extra: dict[str, Any] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ModelCreate(BaseModel):
|
||||
"""Manually register a model id on a connection.
|
||||
|
||||
For providers without a usable ``/models`` endpoint (Perplexity, MiniMax,
|
||||
Azure deployments, etc.) or to pin a single model from a noisy provider.
|
||||
"""
|
||||
|
||||
model_id: str = Field(..., max_length=255)
|
||||
display_name: str | None = Field(None, max_length=255)
|
||||
|
||||
|
||||
class ModelUpdate(BaseModel):
|
||||
display_name: str | None = Field(None, max_length=255)
|
||||
enabled: bool | None = None
|
||||
supports_chat: bool | None = None
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
supports_tools: bool | None = None
|
||||
supports_image_generation: bool | None = None
|
||||
capabilities_override: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ModelsBulkUpdate(BaseModel):
|
||||
model_ids: list[int] = Field(..., min_length=1, max_length=1000)
|
||||
enabled: bool
|
||||
|
||||
|
||||
class ModelProviderRead(BaseModel):
|
||||
provider: str
|
||||
transport: str
|
||||
discovery: str
|
||||
default_base_url: str | None = None
|
||||
base_url_required: bool
|
||||
auth_style: str
|
||||
local_only: bool = False
|
||||
|
||||
|
||||
class VerifyConnectionResponse(BaseModel):
|
||||
status: str
|
||||
ok: bool
|
||||
message: str = ""
|
||||
|
||||
|
||||
class ModelRolesRead(BaseModel):
|
||||
chat_model_id: int | None = 0
|
||||
vision_model_id: int | None = 0
|
||||
image_gen_model_id: int | None = 0
|
||||
|
||||
|
||||
class ModelRolesUpdate(BaseModel):
|
||||
chat_model_id: int | None = None
|
||||
vision_model_id: int | None = None
|
||||
image_gen_model_id: int | None = None
|
||||
|
|
@ -1,256 +0,0 @@
|
|||
"""
|
||||
Pydantic schemas for the NewLLMConfig API.
|
||||
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
|
||||
|
||||
class NewLLMConfigBase(BaseModel):
|
||||
"""Base schema with common fields for NewLLMConfig."""
|
||||
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the configuration"
|
||||
)
|
||||
description: str | None = Field(
|
||||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name without provider prefix"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str = Field(
|
||||
default="",
|
||||
description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.",
|
||||
)
|
||||
use_default_system_instructions: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use default instructions when system_instructions is empty",
|
||||
)
|
||||
citations_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to include citation instructions in the system prompt",
|
||||
)
|
||||
|
||||
|
||||
class NewLLMConfigCreate(NewLLMConfigBase):
|
||||
"""Schema for creating a new NewLLMConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
)
|
||||
|
||||
|
||||
class NewLLMConfigUpdate(BaseModel):
|
||||
"""Schema for updating an existing NewLLMConfig. All fields are optional."""
|
||||
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str | None = None
|
||||
use_default_system_instructions: bool | None = None
|
||||
citations_enabled: bool | None = None
|
||||
|
||||
|
||||
class NewLLMConfigRead(NewLLMConfigBase):
|
||||
"""Schema for reading a NewLLMConfig (includes id and timestamps)."""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (no DB column). Default
|
||||
# True matches the conservative-allow stance — a BYOK row that the
|
||||
# route forgot to augment is not pre-judged. The streaming-task
|
||||
# safety net is the only place a False actually blocks a request.
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map "
|
||||
"(``litellm.supports_vision``) — there is no DB column. "
|
||||
"Default True is the conservative-allow stance for unknown / "
|
||||
"unmapped models."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class NewLLMConfigPublic(BaseModel):
|
||||
"""
|
||||
Public schema for NewLLMConfig that hides the API key.
|
||||
Used when returning configs in list views or to users who shouldn't see keys.
|
||||
"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# Model Configuration (no api_key)
|
||||
provider: LiteLLMProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str
|
||||
use_default_system_instructions: bool
|
||||
citations_enabled: bool
|
||||
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (see NewLLMConfigRead).
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map. "
|
||||
"Default True is the conservative-allow stance."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DefaultSystemInstructionsResponse(BaseModel):
|
||||
"""Response schema for getting default system instructions."""
|
||||
|
||||
default_system_instructions: str = Field(
|
||||
..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template"
|
||||
)
|
||||
|
||||
|
||||
class GlobalNewLLMConfigRead(BaseModel):
|
||||
"""
|
||||
Schema for reading global LLM configs from YAML.
|
||||
Global configs have negative IDs and no search_space_id.
|
||||
API key is hidden for security.
|
||||
|
||||
ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing.
|
||||
"""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="Config ID: 0 for Auto mode, negative for global configs",
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# Model Configuration (no api_key)
|
||||
provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str = ""
|
||||
use_default_system_instructions: bool = True
|
||||
citations_enabled: bool = True
|
||||
|
||||
is_global: bool = True # Always true for global configs
|
||||
is_auto_mode: bool = False # True only for Auto mode (ID 0)
|
||||
|
||||
billing_tier: str = "free"
|
||||
is_premium: bool = False
|
||||
anonymous_enabled: bool = False
|
||||
seo_enabled: bool = False
|
||||
seo_slug: str | None = None
|
||||
seo_title: str | None = None
|
||||
seo_description: str | None = None
|
||||
quota_reserve_tokens: int | None = None
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the model accepts image inputs (multimodal vision). "
|
||||
"Derived server-side: OpenRouter dynamic configs use "
|
||||
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
|
||||
"authoritative model map (``litellm.supports_vision``). The "
|
||||
"new-chat selector hints with a 'No image' badge when this is "
|
||||
"False and there are pending image attachments. The streaming "
|
||||
"task fails fast only when LiteLLM *explicitly* marks a model "
|
||||
"as text-only — unknown / unmapped models default-allow."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM Preferences Schemas (for role assignments)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LLMPreferencesRead(BaseModel):
|
||||
"""Schema for reading LLM preferences (role assignments) for a search space."""
|
||||
|
||||
agent_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for agent/chat tasks"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_config_id: int | None = Field(
|
||||
None,
|
||||
description="ID of the vision LLM config to use for vision/screenshot analysis",
|
||||
)
|
||||
agent_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for agent LLM"
|
||||
)
|
||||
image_generation_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for image generation"
|
||||
)
|
||||
vision_llm_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for vision LLM"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class LLMPreferencesUpdate(BaseModel):
|
||||
"""Schema for updating LLM preferences."""
|
||||
|
||||
agent_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for agent/chat tasks"
|
||||
)
|
||||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_config_id: int | None = Field(
|
||||
None,
|
||||
description="ID of the vision LLM config to use for vision/screenshot analysis",
|
||||
)
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import VisionProvider
|
||||
|
||||
|
||||
class VisionLLMConfigBase(BaseModel):
|
||||
name: str = Field(..., max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: VisionProvider = Field(...)
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str = Field(..., max_length=100)
|
||||
api_key: str = Field(...)
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
class VisionLLMConfigCreate(VisionLLMConfigBase):
|
||||
search_space_id: int = Field(...)
|
||||
|
||||
|
||||
class VisionLLMConfigUpdate(BaseModel):
|
||||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
provider: VisionProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
api_key: str | None = None
|
||||
api_base: str | None = Field(None, max_length=500)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class VisionLLMConfigRead(VisionLLMConfigBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class VisionLLMConfigPublic(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: VisionProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GlobalVisionLLMConfigRead(BaseModel):
|
||||
"""Schema for reading global vision LLM configs from YAML.
|
||||
|
||||
The ``billing_tier`` field allows the frontend to show a Premium/Free
|
||||
badge and (more importantly) tells the backend whether to debit the
|
||||
user's premium credit pool when this config is used. ``"free"`` is
|
||||
the default for backward compatibility — admins must explicitly opt
|
||||
a global config into ``"premium"``.
|
||||
"""
|
||||
|
||||
id: int = Field(...)
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
is_global: bool = True
|
||||
is_auto_mode: bool = False
|
||||
billing_tier: str = Field(
|
||||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_tokens: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional override for the per-call reservation in *tokens* — "
|
||||
"converted to micro-USD via the model's input/output prices at "
|
||||
"reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS."
|
||||
),
|
||||
)
|
||||
input_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional input price in USD/token. Used by pricing_registration to "
|
||||
"register custom Azure / OpenRouter aliases with LiteLLM at startup."
|
||||
),
|
||||
)
|
||||
output_cost_per_token: float | None = Field(
|
||||
default=None,
|
||||
description="Optional output price in USD/token. Pair with input_cost_per_token.",
|
||||
)
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
"""Resolve and persist Auto (Fastest) model pins per chat thread.
|
||||
"""Resolve and persist Auto model pins per chat thread.
|
||||
|
||||
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global LLM config exactly once and
|
||||
Auto is represented by ``chat_model_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global model exactly once and
|
||||
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
|
||||
subsequent turns are stable.
|
||||
|
||||
Single-writer invariant: this module is the only writer of
|
||||
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
|
||||
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
|
||||
``model_connections_routes`` when a search space's ``chat_model_id`` changes).
|
||||
Therefore a non-NULL value unambiguously means "this thread has an
|
||||
Auto-resolved pin"; no separate source/policy column is needed.
|
||||
"""
|
||||
|
|
@ -21,26 +21,35 @@ import time
|
|||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import NewChatThread
|
||||
from app.db import Connection, Model, NewChatThread
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.quality_score import _QUALITY_TOP_K
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTO_FASTEST_ID = 0
|
||||
AUTO_FASTEST_MODE = "auto_fastest"
|
||||
AUTO_MODE_ID = 0
|
||||
# Stable internal hash namespace for deterministic per-thread selection.
|
||||
# Do not rename: changing this rebalances Auto's model choice for new pins.
|
||||
AUTO_PIN_HASH_NAMESPACE = "auto_fastest"
|
||||
_RUNTIME_COOLDOWN_SECONDS = 600
|
||||
_HEALTHY_TTL_SECONDS = 45
|
||||
_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX = "auto:cooldown:llm:"
|
||||
_REDIS_TIMEOUT_SECONDS = 0.2
|
||||
|
||||
# In-memory runtime cooldown map for configs that recently hard-failed at
|
||||
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
|
||||
# the same unhealthy config from being reselected immediately during repair.
|
||||
_runtime_cooldown_until: dict[int, float] = {}
|
||||
_runtime_cooldown_lock = threading.Lock()
|
||||
_runtime_cooldown_redis: redis.Redis | None = None
|
||||
_runtime_cooldown_redis_lock = threading.Lock()
|
||||
|
||||
# Short-TTL "recently healthy" cache for configs that just passed a runtime
|
||||
# preflight ping. Lets back-to-back turns on the same model skip the probe
|
||||
|
|
@ -61,11 +70,15 @@ def _is_usable_global_config(cfg: dict) -> bool:
|
|||
return bool(
|
||||
cfg.get("id") is not None
|
||||
and cfg.get("model_name")
|
||||
and cfg.get("provider")
|
||||
and (cfg.get("provider") or cfg.get("litellm_provider"))
|
||||
and cfg.get("api_key")
|
||||
)
|
||||
|
||||
|
||||
def _has_capability(model: dict | Model, capability: str) -> bool:
|
||||
return has_capability(model, capability)
|
||||
|
||||
|
||||
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
|
||||
now = time.time() if now_ts is None else now_ts
|
||||
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
|
||||
|
|
@ -79,6 +92,81 @@ def _is_runtime_cooled_down(config_id: int) -> bool:
|
|||
return config_id in _runtime_cooldown_until
|
||||
|
||||
|
||||
def _runtime_cooldown_redis_key(config_id: int) -> str:
|
||||
return f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}{int(config_id)}"
|
||||
|
||||
|
||||
def _get_runtime_cooldown_redis() -> redis.Redis:
|
||||
global _runtime_cooldown_redis
|
||||
if _runtime_cooldown_redis is None:
|
||||
with _runtime_cooldown_redis_lock:
|
||||
if _runtime_cooldown_redis is None:
|
||||
_runtime_cooldown_redis = redis.from_url(
|
||||
config.REDIS_APP_URL,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=_REDIS_TIMEOUT_SECONDS,
|
||||
socket_timeout=_REDIS_TIMEOUT_SECONDS,
|
||||
)
|
||||
return _runtime_cooldown_redis
|
||||
|
||||
|
||||
def _mark_shared_runtime_cooldown(
|
||||
config_id: int,
|
||||
*,
|
||||
reason: str,
|
||||
cooldown_seconds: int,
|
||||
) -> None:
|
||||
try:
|
||||
_get_runtime_cooldown_redis().set(
|
||||
_runtime_cooldown_redis_key(config_id),
|
||||
reason,
|
||||
ex=int(cooldown_seconds),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"auto_pin_runtime_cooldown_redis_write_failed config_id=%s",
|
||||
config_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]:
|
||||
unique_ids = list(dict.fromkeys(int(cid) for cid in config_ids))
|
||||
if not unique_ids:
|
||||
return set()
|
||||
try:
|
||||
values = _get_runtime_cooldown_redis().mget(
|
||||
[_runtime_cooldown_redis_key(cid) for cid in unique_ids]
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"auto_pin_runtime_cooldown_redis_read_failed count=%s",
|
||||
len(unique_ids),
|
||||
exc_info=True,
|
||||
)
|
||||
return set()
|
||||
return {
|
||||
cid for cid, value in zip(unique_ids, values, strict=False) if value is not None
|
||||
}
|
||||
|
||||
|
||||
def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None:
|
||||
try:
|
||||
client = _get_runtime_cooldown_redis()
|
||||
if config_id is not None:
|
||||
client.delete(_runtime_cooldown_redis_key(config_id))
|
||||
return
|
||||
keys = list(client.scan_iter(f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}*"))
|
||||
if keys:
|
||||
client.delete(*keys)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"auto_pin_runtime_cooldown_redis_clear_failed config_id=%s",
|
||||
config_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def mark_runtime_cooldown(
|
||||
config_id: int,
|
||||
*,
|
||||
|
|
@ -97,6 +185,11 @@ def mark_runtime_cooldown(
|
|||
with _runtime_cooldown_lock:
|
||||
_runtime_cooldown_until[int(config_id)] = until
|
||||
_prune_runtime_cooldowns()
|
||||
_mark_shared_runtime_cooldown(
|
||||
int(config_id),
|
||||
reason=reason,
|
||||
cooldown_seconds=int(cooldown_seconds),
|
||||
)
|
||||
# A cooled cfg can never be "recently healthy"; drop any stale credit so
|
||||
# the next turn that resolves to it (after cooldown) re-runs preflight.
|
||||
clear_healthy(int(config_id))
|
||||
|
|
@ -113,8 +206,9 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None:
|
|||
with _runtime_cooldown_lock:
|
||||
if config_id is None:
|
||||
_runtime_cooldown_until.clear()
|
||||
return
|
||||
_runtime_cooldown_until.pop(int(config_id), None)
|
||||
else:
|
||||
_runtime_cooldown_until.pop(int(config_id), None)
|
||||
_clear_shared_runtime_cooldown(config_id)
|
||||
|
||||
|
||||
def _prune_healthy(now_ts: float | None = None) -> None:
|
||||
|
|
@ -186,15 +280,20 @@ def _cfg_supports_image_input(cfg: dict) -> bool:
|
|||
else None
|
||||
)
|
||||
return derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
|
||||
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
|
||||
"""Return Auto-eligible global cfgs.
|
||||
def _global_candidates(
|
||||
*,
|
||||
capability: str = "chat",
|
||||
requires_image_input: bool = False,
|
||||
shared_cooled_down_ids: set[int] | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return Auto-eligible global virtual models.
|
||||
|
||||
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||
|
|
@ -205,30 +304,167 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
|
|||
filters out configs whose ``supports_image_input`` resolves to False
|
||||
so a text-only deployment can't be pinned for an image request.
|
||||
"""
|
||||
candidates = [
|
||||
cfg
|
||||
connection_by_id = {
|
||||
int(conn.get("id")): conn
|
||||
for conn in config.GLOBAL_CONNECTIONS
|
||||
if conn.get("id") is not None
|
||||
}
|
||||
config_by_model_name = {
|
||||
cfg.get("model_name"): cfg
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS
|
||||
if _is_usable_global_config(cfg)
|
||||
and not cfg.get("health_gated")
|
||||
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||
and (not requires_image_input or _cfg_supports_image_input(cfg))
|
||||
]
|
||||
}
|
||||
candidates: list[dict] = []
|
||||
shared_cooled_down_ids = shared_cooled_down_ids or set()
|
||||
for model in config.GLOBAL_MODELS:
|
||||
model_id = int(model.get("id", 0))
|
||||
if (
|
||||
model_id >= 0
|
||||
or _is_runtime_cooled_down(model_id)
|
||||
or model_id in shared_cooled_down_ids
|
||||
):
|
||||
continue
|
||||
if not _has_capability(model, capability):
|
||||
continue
|
||||
cfg = config_by_model_name.get(model.get("model_id")) or {}
|
||||
if cfg.get("health_gated"):
|
||||
continue
|
||||
if requires_image_input and not _has_capability(model, "vision"):
|
||||
continue
|
||||
if requires_image_input and cfg and not _cfg_supports_image_input(cfg):
|
||||
continue
|
||||
connection = connection_by_id.get(int(model.get("connection_id", 0)))
|
||||
if not connection:
|
||||
continue
|
||||
catalog = model.get("catalog") or {}
|
||||
candidates.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"model_id": model.get("model_id"),
|
||||
"source": "global",
|
||||
"connection": connection,
|
||||
"supports_chat": model.get("supports_chat"),
|
||||
"supports_image_input": model.get("supports_image_input"),
|
||||
"supports_tools": model.get("supports_tools"),
|
||||
"supports_image_generation": model.get("supports_image_generation"),
|
||||
"capabilities_override": model.get("capabilities_override") or {},
|
||||
"billing_tier": model.get("billing_tier", "free"),
|
||||
"provider": connection.get("provider"),
|
||||
"model_name": model.get("model_id"),
|
||||
"auto_pin_tier": catalog.get("auto_pin_tier")
|
||||
or cfg.get("auto_pin_tier")
|
||||
or "A",
|
||||
"quality_score": catalog.get("quality_score")
|
||||
or cfg.get("quality_score")
|
||||
or cfg.get("quality_score_static")
|
||||
or 50,
|
||||
}
|
||||
)
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
||||
async def _db_candidates(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
user_id: str | UUID | None,
|
||||
capability: str,
|
||||
requires_image_input: bool = False,
|
||||
) -> list[dict]:
|
||||
parsed_user_id = _to_uuid(user_id)
|
||||
stmt = (
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.join(Connection, Model.connection_id == Connection.id)
|
||||
.where(Model.enabled.is_(True), Connection.enabled.is_(True))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
models = result.scalars().all()
|
||||
shared_cooled_down_ids = _shared_runtime_cooled_down_ids(
|
||||
[int(model.id) for model in models]
|
||||
)
|
||||
candidates: list[dict] = []
|
||||
for model in models:
|
||||
conn = model.connection
|
||||
if not conn:
|
||||
continue
|
||||
if conn.search_space_id is not None and conn.search_space_id != search_space_id:
|
||||
continue
|
||||
if (
|
||||
conn.user_id is not None
|
||||
and parsed_user_id is not None
|
||||
and conn.user_id != parsed_user_id
|
||||
):
|
||||
continue
|
||||
if conn.user_id is not None and parsed_user_id is None:
|
||||
continue
|
||||
if not _has_capability(model, capability):
|
||||
continue
|
||||
if requires_image_input and not _has_capability(model, "vision"):
|
||||
continue
|
||||
model_id = int(model.id)
|
||||
if _is_runtime_cooled_down(model_id) or model_id in shared_cooled_down_ids:
|
||||
continue
|
||||
catalog = model.catalog or {}
|
||||
candidates.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"model_id": model.model_id,
|
||||
"source": "db",
|
||||
"connection": conn,
|
||||
"supports_chat": model.supports_chat,
|
||||
"supports_image_input": model.supports_image_input,
|
||||
"supports_tools": model.supports_tools,
|
||||
"supports_image_generation": model.supports_image_generation,
|
||||
"capabilities_override": model.capabilities_override or {},
|
||||
"billing_tier": "byok",
|
||||
"provider": conn.provider,
|
||||
"model_name": model.model_id,
|
||||
"auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK",
|
||||
"quality_score": catalog.get("quality_score") or 75,
|
||||
}
|
||||
)
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
||||
async def auto_model_candidates(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
user_id: str | UUID | None,
|
||||
capability: str,
|
||||
requires_image_input: bool = False,
|
||||
exclude_model_ids: set[int] | None = None,
|
||||
) -> list[dict]:
|
||||
excluded_ids = {int(mid) for mid in (exclude_model_ids or set())}
|
||||
global_ids = [
|
||||
int(model.get("id", 0))
|
||||
for model in config.GLOBAL_MODELS
|
||||
if int(model.get("id", 0)) < 0
|
||||
]
|
||||
shared_global_cooled_down_ids = _shared_runtime_cooled_down_ids(global_ids)
|
||||
db_candidates = await _db_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
capability=capability,
|
||||
requires_image_input=requires_image_input,
|
||||
)
|
||||
candidates = [
|
||||
*_global_candidates(
|
||||
capability=capability,
|
||||
requires_image_input=requires_image_input,
|
||||
shared_cooled_down_ids=shared_global_cooled_down_ids,
|
||||
),
|
||||
*db_candidates,
|
||||
]
|
||||
return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids]
|
||||
|
||||
|
||||
def _tier_of(cfg: dict) -> str:
|
||||
return str(cfg.get("billing_tier", "free")).lower()
|
||||
|
||||
|
||||
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
|
||||
"""Return True for the operator-preferred premium Auto model."""
|
||||
return (
|
||||
_tier_of(cfg) == "premium"
|
||||
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
|
||||
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
|
||||
)
|
||||
|
||||
|
||||
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
||||
"""Pick a config with quality-first ranking + deterministic spread.
|
||||
|
||||
|
|
@ -246,11 +482,16 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
|||
pool = tier_a if tier_a else eligible
|
||||
pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
|
||||
top_k = pool[:_QUALITY_TOP_K]
|
||||
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
|
||||
digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest()
|
||||
idx = int.from_bytes(digest[:8], "big") % len(top_k)
|
||||
return top_k[idx], len(top_k)
|
||||
|
||||
|
||||
def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict:
|
||||
selected, _ = _select_pin(candidates, seed_id)
|
||||
return selected
|
||||
|
||||
|
||||
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
||||
if user_id is None:
|
||||
return None
|
||||
|
|
@ -283,7 +524,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
exclude_config_ids: set[int] | None = None,
|
||||
requires_image_input: bool = False,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||
"""Resolve Auto to one concrete config id and persist the pin.
|
||||
|
||||
For non-auto selections, this function clears any existing pin and returns
|
||||
the selected id as-is.
|
||||
|
|
@ -315,7 +556,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
)
|
||||
|
||||
# Explicit model selected: clear any stale pin.
|
||||
if selected_llm_config_id != AUTO_FASTEST_ID:
|
||||
if selected_llm_config_id != AUTO_MODE_ID:
|
||||
if thread.pinned_llm_config_id is not None:
|
||||
thread.pinned_llm_config_id = None
|
||||
await session.commit()
|
||||
|
|
@ -326,20 +567,21 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
)
|
||||
|
||||
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||
candidates = [
|
||||
c
|
||||
for c in _global_candidates(requires_image_input=requires_image_input)
|
||||
if int(c.get("id", 0)) not in excluded_ids
|
||||
]
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
capability="chat",
|
||||
requires_image_input=requires_image_input,
|
||||
exclude_model_ids=excluded_ids,
|
||||
)
|
||||
if not candidates:
|
||||
if requires_image_input:
|
||||
# Distinguish the "no vision-capable cfg" case from generic
|
||||
# "no usable cfg" so the streaming task can map this to the
|
||||
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
|
||||
raise ValueError(
|
||||
"No vision-capable global LLM configs are available for Auto mode"
|
||||
)
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
raise ValueError("No vision-capable LLM models are available for Auto mode")
|
||||
raise ValueError("No usable LLM models are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||
|
|
@ -379,24 +621,13 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
# log that explicitly so operators can correlate the re-pin with
|
||||
# the user's image attachment instead of suspecting a cooldown.
|
||||
if requires_image_input:
|
||||
try:
|
||||
pinned_global = next(
|
||||
c
|
||||
for c in config.GLOBAL_LLM_CONFIGS
|
||||
if int(c.get("id", 0)) == int(pinned_id)
|
||||
)
|
||||
except StopIteration:
|
||||
pinned_global = None
|
||||
if pinned_global is not None and not _cfg_supports_image_input(
|
||||
pinned_global
|
||||
):
|
||||
logger.info(
|
||||
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
|
||||
"previous_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
)
|
||||
logger.info(
|
||||
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
|
||||
"previous_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
)
|
||||
logger.info(
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
||||
thread_id,
|
||||
|
|
@ -407,12 +638,10 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
premium_eligible = (
|
||||
False if force_repin_free else await _is_premium_eligible(session, user_id)
|
||||
)
|
||||
byok_candidates = [c for c in candidates if _tier_of(c) == "byok"]
|
||||
if premium_eligible:
|
||||
premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
|
||||
preferred_premium = [
|
||||
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
|
||||
]
|
||||
eligible = preferred_premium or premium_candidates
|
||||
eligible = premium_candidates or byok_candidates
|
||||
else:
|
||||
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||
|
||||
|
|
|
|||
|
|
@ -445,15 +445,15 @@ async def _resolve_agent_billing_for_search_space(
|
|||
thread_id: int | None = None,
|
||||
) -> tuple[UUID, str, str]:
|
||||
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
|
||||
agent LLM.
|
||||
chat model.
|
||||
|
||||
Used by Celery tasks (podcast generation, video presentation) to bill the
|
||||
search-space owner's premium credit pool when the agent LLM is premium.
|
||||
search-space owner's premium credit pool when the chat model is premium.
|
||||
|
||||
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
|
||||
Resolution rules mirror the chat model role resolver:
|
||||
|
||||
- Search space not found / no ``agent_llm_id``: raise ``ValueError``.
|
||||
- **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
|
||||
- Search space not found / no ``chat_model_id``: raise ``ValueError``.
|
||||
- **Auto mode** (``id == AUTO_MODE_ID == 0``):
|
||||
* ``thread_id`` is set: delegate to
|
||||
``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
|
||||
recurse into the resolved id. Reuses chat's existing pin if present
|
||||
|
|
@ -469,9 +469,8 @@ async def _resolve_agent_billing_for_search_space(
|
|||
(defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
|
||||
``base_model = litellm_params.get("base_model") or model_name`` —
|
||||
NOT provider-prefixed, matching chat's cost-map lookup convention.
|
||||
- **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
|
||||
``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
|
||||
``base_model`` from ``litellm_params`` or ``model_name``.
|
||||
- **Positive id** (user BYOK ``Model``): always free; ``base_model`` from
|
||||
the model catalog override or the upstream ``model_id``.
|
||||
|
||||
Note on imports: ``llm_service``, ``auto_model_pin_service``, and
|
||||
``llm_router_service`` are imported lazily inside the function body to
|
||||
|
|
@ -480,8 +479,9 @@ async def _resolve_agent_billing_for_search_space(
|
|||
``billable_calls.py``'s module load path.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.db import Model, SearchSpace
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
|
|
@ -490,20 +490,20 @@ async def _resolve_agent_billing_for_search_space(
|
|||
if search_space is None:
|
||||
raise ValueError(f"Search space {search_space_id} not found")
|
||||
|
||||
agent_llm_id = search_space.agent_llm_id
|
||||
if agent_llm_id is None:
|
||||
chat_model_id = search_space.chat_model_id
|
||||
if chat_model_id is None:
|
||||
raise ValueError(
|
||||
f"Search space {search_space_id} has no agent_llm_id configured"
|
||||
f"Search space {search_space_id} has no chat_model_id configured"
|
||||
)
|
||||
|
||||
owner_user_id: UUID = search_space.user_id
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
AUTO_FASTEST_ID,
|
||||
AUTO_MODE_ID,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
if agent_llm_id == AUTO_FASTEST_ID:
|
||||
if chat_model_id == AUTO_MODE_ID:
|
||||
if thread_id is None:
|
||||
return owner_user_id, "free", "auto"
|
||||
try:
|
||||
|
|
@ -512,7 +512,7 @@ async def _resolve_agent_billing_for_search_space(
|
|||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=str(owner_user_id),
|
||||
selected_llm_config_id=AUTO_FASTEST_ID,
|
||||
selected_llm_config_id=AUTO_MODE_ID,
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
|
|
@ -523,28 +523,35 @@ async def _resolve_agent_billing_for_search_space(
|
|||
exc_info=True,
|
||||
)
|
||||
return owner_user_id, "free", "auto"
|
||||
agent_llm_id = resolution.resolved_llm_config_id
|
||||
chat_model_id = resolution.resolved_llm_config_id
|
||||
|
||||
if agent_llm_id < 0:
|
||||
if chat_model_id < 0:
|
||||
from app.services.llm_service import get_global_llm_config
|
||||
|
||||
cfg = get_global_llm_config(agent_llm_id) or {}
|
||||
cfg = get_global_llm_config(chat_model_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
|
||||
return owner_user_id, billing_tier, base_model
|
||||
|
||||
nlc_result = await session.execute(
|
||||
select(NewLLMConfig).where(
|
||||
NewLLMConfig.id == agent_llm_id,
|
||||
NewLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
model_result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == chat_model_id, Model.enabled.is_(True))
|
||||
)
|
||||
nlc = nlc_result.scalars().first()
|
||||
model = model_result.scalars().first()
|
||||
base_model = ""
|
||||
if nlc is not None:
|
||||
litellm_params = nlc.litellm_params or {}
|
||||
base_model = litellm_params.get("base_model") or nlc.model_name or ""
|
||||
if (
|
||||
model is not None
|
||||
and model.connection is not None
|
||||
and model.connection.enabled
|
||||
and (
|
||||
model.connection.search_space_id in (None, search_space_id)
|
||||
and model.connection.user_id in (None, owner_user_id)
|
||||
)
|
||||
):
|
||||
catalog = model.catalog or {}
|
||||
base_model = catalog.get("base_model") or model.model_id or ""
|
||||
return owner_user_id, "free", base_model
|
||||
|
||||
|
||||
|
|
|
|||
128
surfsense_backend/app/services/global_model_catalog.py
Normal file
128
surfsense_backend/app/services/global_model_catalog.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
"""Materialize server-owned GLOBAL YAML configs as virtual connections/models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.services.model_resolver import native_connection_from_config
|
||||
|
||||
|
||||
def _base_model(config: dict[str, Any]) -> str | None:
|
||||
litellm_params = config.get("litellm_params") or {}
|
||||
if isinstance(litellm_params, dict):
|
||||
return litellm_params.get("base_model")
|
||||
return None
|
||||
|
||||
|
||||
def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]:
|
||||
# Deliberately includes api_key because two operator-owned credentials for
|
||||
# the same provider/base can have different quota/rate limits upstream.
|
||||
return (
|
||||
conn.get("provider"),
|
||||
conn.get("base_url"),
|
||||
conn.get("api_key"),
|
||||
_freeze(conn.get("extra") or {}),
|
||||
)
|
||||
|
||||
|
||||
def _freeze(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return tuple(sorted((key, _freeze(val)) for key, val in value.items()))
|
||||
if isinstance(value, list):
|
||||
return tuple(_freeze(item) for item in value)
|
||||
return value
|
||||
|
||||
|
||||
def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"billing_tier": config.get("billing_tier", "free"),
|
||||
"quota_reserve_tokens": config.get("quota_reserve_tokens"),
|
||||
"rpm": config.get("rpm"),
|
||||
"tpm": config.get("tpm"),
|
||||
"anonymous_enabled": config.get("anonymous_enabled", False),
|
||||
"seo_enabled": config.get("seo_enabled", False),
|
||||
"seo_slug": config.get("seo_slug"),
|
||||
"input_cost_per_token": (config.get("litellm_params") or {}).get(
|
||||
"input_cost_per_token"
|
||||
)
|
||||
if isinstance(config.get("litellm_params"), dict)
|
||||
else None,
|
||||
"output_cost_per_token": (config.get("litellm_params") or {}).get(
|
||||
"output_cost_per_token"
|
||||
)
|
||||
if isinstance(config.get("litellm_params"), dict)
|
||||
else None,
|
||||
"is_planner": config.get("is_planner", False),
|
||||
"base_model": _base_model(config),
|
||||
"router_pool_eligible": config.get("router_pool_eligible", True),
|
||||
}
|
||||
|
||||
|
||||
def materialize_global_model_catalog(
|
||||
*,
|
||||
chat_configs: list[dict[str, Any]],
|
||||
image_configs: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
connections: list[dict[str, Any]] = []
|
||||
models: list[dict[str, Any]] = []
|
||||
connection_id_by_key: dict[tuple[Any, ...], int] = {}
|
||||
next_connection_id = -1
|
||||
|
||||
def add_config(config: dict[str, Any], role: str) -> None:
|
||||
nonlocal next_connection_id
|
||||
if not config.get("id") or not config.get("model_name"):
|
||||
return
|
||||
conn = native_connection_from_config(config)
|
||||
conn["scope"] = "GLOBAL"
|
||||
conn["enabled"] = True
|
||||
key = _connection_key(conn)
|
||||
connection_id = connection_id_by_key.get(key)
|
||||
if connection_id is None:
|
||||
connection_id = next_connection_id
|
||||
next_connection_id -= 1
|
||||
connection_id_by_key[key] = connection_id
|
||||
connections.append(
|
||||
{
|
||||
"id": connection_id,
|
||||
**conn,
|
||||
}
|
||||
)
|
||||
|
||||
model_id = int(config["id"])
|
||||
models.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"connection_id": connection_id,
|
||||
"model_id": config["model_name"],
|
||||
"display_name": config.get("name") or config["model_name"],
|
||||
"source": "MANUAL",
|
||||
"supports_chat": role == "chat",
|
||||
"max_input_tokens": config.get("max_input_tokens"),
|
||||
"supports_image_input": (
|
||||
role == "chat" and bool(config.get("supports_image_input"))
|
||||
),
|
||||
"supports_tools": bool(config.get("supports_tools", False)),
|
||||
"supports_image_generation": role == "image_gen",
|
||||
"capabilities_override": {},
|
||||
"enabled": True,
|
||||
"billing_tier": config.get("billing_tier", "free"),
|
||||
"catalog": _catalog_metadata(config),
|
||||
"role": role,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in chat_configs:
|
||||
if cfg.get("is_auto_mode"):
|
||||
continue
|
||||
add_config(cfg, "chat")
|
||||
for cfg in image_configs:
|
||||
if cfg.get("is_auto_mode"):
|
||||
continue
|
||||
add_config(cfg, "image_gen")
|
||||
|
||||
# Each virtual connection is server-only. Callers that serialize these
|
||||
# must strip api_key before returning data to clients.
|
||||
return connections, models
|
||||
|
||||
|
||||
__all__ = ["materialize_global_model_catalog"]
|
||||
|
|
@ -20,28 +20,13 @@ from typing import Any
|
|||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
IMAGE_GEN_AUTO_MODE_ID = 0
|
||||
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
IMAGE_GEN_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
class ImageGenRouterService:
|
||||
"""
|
||||
|
|
@ -153,38 +138,11 @@ class ImageGenRouterService:
|
|||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
provider_prefix = config["custom_provider"]
|
||||
else:
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Resolve ``api_base`` so deployments don't silently inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
|
||||
# the wrong provider (see ``provider_api_base`` docstring).
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(config),
|
||||
config["model_name"],
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add api_version (required for Azure)
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs}
|
||||
|
||||
# All configs use same alias "auto" for unified routing
|
||||
deployment: dict[str, Any] = {
|
||||
|
|
|
|||
257
surfsense_backend/app/services/llm_error_adapter.py
Normal file
257
surfsense_backend/app/services/llm_error_adapter.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""Normalize provider/LLM exceptions into low-cardinality product categories."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LLMErrorCategory(StrEnum):
|
||||
RATE_LIMITED = "rate_limited"
|
||||
TIMEOUT = "timeout"
|
||||
PROVIDER_UNAVAILABLE = "provider_unavailable"
|
||||
BAD_GATEWAY = "bad_gateway"
|
||||
CONNECTION_FAILED = "connection_failed"
|
||||
AUTH_FAILED = "auth_failed"
|
||||
PERMISSION_DENIED = "permission_denied"
|
||||
MODEL_NOT_FOUND = "model_not_found"
|
||||
BAD_REQUEST = "bad_request"
|
||||
CONTEXT_LIMIT = "context_limit"
|
||||
RESPONSE_INVALID = "response_invalid"
|
||||
SERVER_ERROR = "server_error"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMErrorAdaptation:
|
||||
category: LLMErrorCategory
|
||||
retryable: bool
|
||||
user_message: str
|
||||
provider_status_code: int | None = None
|
||||
provider_error_type: str | None = None
|
||||
|
||||
|
||||
_CATEGORY_MESSAGES: dict[LLMErrorCategory, str] = {
|
||||
LLMErrorCategory.RATE_LIMITED: "LLM rate limit exceeded. Will retry on next sync.",
|
||||
LLMErrorCategory.TIMEOUT: "LLM request timed out. Will retry on next sync.",
|
||||
LLMErrorCategory.PROVIDER_UNAVAILABLE: "LLM service temporarily unavailable. Will retry on next sync.",
|
||||
LLMErrorCategory.BAD_GATEWAY: "LLM gateway error. Will retry on next sync.",
|
||||
LLMErrorCategory.CONNECTION_FAILED: "Could not reach the LLM service. Check network connectivity.",
|
||||
LLMErrorCategory.AUTH_FAILED: "LLM authentication failed. Check your API key.",
|
||||
LLMErrorCategory.PERMISSION_DENIED: "LLM request denied. Check your account permissions.",
|
||||
LLMErrorCategory.MODEL_NOT_FOUND: "Model not found. Check your model configuration.",
|
||||
LLMErrorCategory.BAD_REQUEST: "LLM rejected the request. Document content may be invalid.",
|
||||
LLMErrorCategory.CONTEXT_LIMIT: "Document exceeds the LLM context window even after optimization.",
|
||||
LLMErrorCategory.RESPONSE_INVALID: "LLM returned an invalid response.",
|
||||
LLMErrorCategory.SERVER_ERROR: "LLM internal server error. Will retry on next sync.",
|
||||
LLMErrorCategory.UNKNOWN: "Something went wrong when calling the LLM.",
|
||||
}
|
||||
|
||||
_RETRYABLE_CATEGORIES = {
|
||||
LLMErrorCategory.RATE_LIMITED,
|
||||
LLMErrorCategory.TIMEOUT,
|
||||
LLMErrorCategory.PROVIDER_UNAVAILABLE,
|
||||
LLMErrorCategory.BAD_GATEWAY,
|
||||
LLMErrorCategory.CONNECTION_FAILED,
|
||||
LLMErrorCategory.SERVER_ERROR,
|
||||
}
|
||||
|
||||
_CLASS_NAME_MAP: tuple[tuple[LLMErrorCategory, tuple[str, ...]], ...] = (
|
||||
(
|
||||
LLMErrorCategory.RATE_LIMITED,
|
||||
("RateLimitError", "TooManyRequests", "TooManyRequestsError"),
|
||||
),
|
||||
(LLMErrorCategory.TIMEOUT, ("Timeout", "APITimeoutError", "TimeoutException")),
|
||||
(
|
||||
LLMErrorCategory.PROVIDER_UNAVAILABLE,
|
||||
("ServiceUnavailableError", "ServiceUnavailable"),
|
||||
),
|
||||
(
|
||||
LLMErrorCategory.BAD_GATEWAY,
|
||||
("BadGatewayError", "GatewayTimeoutError"),
|
||||
),
|
||||
(
|
||||
LLMErrorCategory.CONNECTION_FAILED,
|
||||
("APIConnectionError", "ConnectError", "ConnectTimeout", "ReadTimeout"),
|
||||
),
|
||||
(
|
||||
LLMErrorCategory.AUTH_FAILED,
|
||||
("AuthenticationError", "InvalidApiKey", "InvalidAPIKey", "InvalidApiKeyError"),
|
||||
),
|
||||
(LLMErrorCategory.PERMISSION_DENIED, ("PermissionDeniedError", "ForbiddenError")),
|
||||
(LLMErrorCategory.MODEL_NOT_FOUND, ("NotFoundError", "ModelNotFoundError")),
|
||||
(
|
||||
LLMErrorCategory.CONTEXT_LIMIT,
|
||||
("ContextWindowExceeded", "ContextOverflow", "ContextLimit"),
|
||||
),
|
||||
(
|
||||
LLMErrorCategory.RESPONSE_INVALID,
|
||||
("APIResponseValidationError", "ResponseValidationError"),
|
||||
),
|
||||
(
|
||||
LLMErrorCategory.BAD_REQUEST,
|
||||
("BadRequestError", "InvalidRequestError", "UnprocessableEntityError"),
|
||||
),
|
||||
(LLMErrorCategory.SERVER_ERROR, ("InternalServerError",)),
|
||||
)
|
||||
|
||||
|
||||
def _parse_error_payload(message: str) -> dict[str, Any] | None:
|
||||
candidates = [message]
|
||||
first_brace_idx = message.find("{")
|
||||
if first_brace_idx >= 0:
|
||||
candidates.append(message[first_brace_idx:])
|
||||
|
||||
for candidate in candidates:
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _class_names(exc: BaseException) -> tuple[str, ...]:
|
||||
return tuple(cls.__name__ for cls in type(exc).__mro__)
|
||||
|
||||
|
||||
def _category_from_class_name(exc: BaseException) -> LLMErrorCategory | None:
|
||||
names = _class_names(exc)
|
||||
for category, hints in _CLASS_NAME_MAP:
|
||||
if any(any(hint in name for hint in hints) for name in names):
|
||||
return category
|
||||
return None
|
||||
|
||||
|
||||
def _extract_provider_status_code(parsed: dict[str, Any] | None) -> int | None:
|
||||
if not isinstance(parsed, dict):
|
||||
return None
|
||||
candidates: list[Any] = [parsed.get("code"), parsed.get("status")]
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
candidates.extend([nested.get("code"), nested.get("status")])
|
||||
for value in candidates:
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
return int(value)
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _extract_provider_error_type(parsed: dict[str, Any] | None) -> str | None:
|
||||
if not isinstance(parsed, dict):
|
||||
return None
|
||||
candidates: list[Any] = [parsed.get("type")]
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
candidates.append(nested.get("type"))
|
||||
for value in candidates:
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _category_from_provider_payload(
|
||||
status_code: int | None,
|
||||
provider_error_type: str | None,
|
||||
) -> LLMErrorCategory | None:
|
||||
if status_code == 429:
|
||||
return LLMErrorCategory.RATE_LIMITED
|
||||
if status_code == 401:
|
||||
return LLMErrorCategory.AUTH_FAILED
|
||||
if status_code == 403:
|
||||
return LLMErrorCategory.PERMISSION_DENIED
|
||||
if status_code == 404:
|
||||
return LLMErrorCategory.MODEL_NOT_FOUND
|
||||
if status_code in (400, 422):
|
||||
return LLMErrorCategory.BAD_REQUEST
|
||||
if status_code in (502, 504):
|
||||
return LLMErrorCategory.BAD_GATEWAY
|
||||
if status_code == 503:
|
||||
return LLMErrorCategory.PROVIDER_UNAVAILABLE
|
||||
if status_code is not None and status_code >= 500:
|
||||
return LLMErrorCategory.SERVER_ERROR
|
||||
|
||||
normalized_type = (provider_error_type or "").lower()
|
||||
if normalized_type == "rate_limit_error":
|
||||
return LLMErrorCategory.RATE_LIMITED
|
||||
if normalized_type in {
|
||||
"authentication_error",
|
||||
"invalid_api_key",
|
||||
"invalid_api_key_error",
|
||||
}:
|
||||
return LLMErrorCategory.AUTH_FAILED
|
||||
if normalized_type in {"permission_denied", "forbidden"}:
|
||||
return LLMErrorCategory.PERMISSION_DENIED
|
||||
if normalized_type in {"not_found_error", "model_not_found"}:
|
||||
return LLMErrorCategory.MODEL_NOT_FOUND
|
||||
if normalized_type in {"context_length_exceeded", "context_window_exceeded"}:
|
||||
return LLMErrorCategory.CONTEXT_LIMIT
|
||||
return None
|
||||
|
||||
|
||||
def _category_from_message(raw: str) -> LLMErrorCategory | None:
|
||||
lowered = raw.lower()
|
||||
if any(
|
||||
hint in lowered
|
||||
for hint in ("rate limit", "rate-limited", "temporarily rate-limited")
|
||||
):
|
||||
return LLMErrorCategory.RATE_LIMITED
|
||||
if any(
|
||||
hint in lowered
|
||||
for hint in (
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"user not found",
|
||||
"api key is expired",
|
||||
"expired api key",
|
||||
)
|
||||
):
|
||||
return LLMErrorCategory.AUTH_FAILED
|
||||
if "forbidden" in lowered or "permission denied" in lowered:
|
||||
return LLMErrorCategory.PERMISSION_DENIED
|
||||
if "model not found" in lowered:
|
||||
return LLMErrorCategory.MODEL_NOT_FOUND
|
||||
if any(
|
||||
hint in lowered
|
||||
for hint in (
|
||||
"context length",
|
||||
"context window",
|
||||
"maximum context",
|
||||
"too many tokens",
|
||||
)
|
||||
):
|
||||
return LLMErrorCategory.CONTEXT_LIMIT
|
||||
return None
|
||||
|
||||
|
||||
def adapt_llm_exception(exc: BaseException) -> LLMErrorAdaptation:
|
||||
raw = str(exc)
|
||||
parsed = _parse_error_payload(raw)
|
||||
status_code = _extract_provider_status_code(parsed)
|
||||
provider_error_type = _extract_provider_error_type(parsed)
|
||||
|
||||
category = (
|
||||
_category_from_provider_payload(status_code, provider_error_type)
|
||||
or _category_from_message(raw)
|
||||
or _category_from_class_name(exc)
|
||||
or LLMErrorCategory.UNKNOWN
|
||||
)
|
||||
return LLMErrorAdaptation(
|
||||
category=category,
|
||||
retryable=category in _RETRYABLE_CATEGORIES,
|
||||
user_message=_CATEGORY_MESSAGES[category],
|
||||
provider_status_code=status_code,
|
||||
provider_error_type=provider_error_type,
|
||||
)
|
||||
|
||||
|
||||
def llm_error_message(exc: BaseException) -> str:
|
||||
return adapt_llm_exception(exc).user_message
|
||||
|
|
@ -30,6 +30,7 @@ from litellm.exceptions import (
|
|||
)
|
||||
from pydantic import Field
|
||||
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
litellm.json_logs = False
|
||||
|
|
@ -96,53 +97,6 @@ def _sanitize_content(content: Any) -> Any:
|
|||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock", # Legacy support
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
|
||||
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
|
||||
# call sites can share the exact same defense (OpenRouter / Groq / etc.
|
||||
# 404-ing against an inherited Azure endpoint). Re-exported here for
|
||||
# backward compatibility with any external import.
|
||||
from app.services.provider_api_base import ( # noqa: E402
|
||||
resolve_api_base,
|
||||
)
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
"""
|
||||
|
|
@ -420,38 +374,11 @@ class LLMRouterService:
|
|||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
provider_prefix = config["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
else:
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||
# provider-aware default so the deployment does not silently
|
||||
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||
# requests to the wrong endpoint. See ``provider_api_base``
|
||||
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||
# against an Azure endpoint).
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(config),
|
||||
config["model_name"],
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
litellm_params = {"model": model_string, **resolved_kwargs}
|
||||
|
||||
# Extract rate limits if provided
|
||||
deployment = {
|
||||
|
|
|
|||
|
|
@ -6,17 +6,21 @@ from langchain_core.messages import HumanMessage
|
|||
from langchain_litellm import ChatLiteLLM
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.db import Model, SearchSpace
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
|
|
@ -66,6 +70,29 @@ def _is_interactive_auth_provider(
|
|||
return False
|
||||
|
||||
|
||||
def _legacy_config_connection(
|
||||
*,
|
||||
provider: str,
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_provider: str | None = None,
|
||||
litellm_params: dict | None = None,
|
||||
api_version: str | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
cfg = {
|
||||
"provider": provider.lower(),
|
||||
"model_name": model_name,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"custom_provider": custom_provider,
|
||||
"api_version": api_version,
|
||||
"litellm_params": litellm_params or {},
|
||||
}
|
||||
conn = native_connection_from_config(cfg)
|
||||
return to_litellm(conn, model_name)
|
||||
|
||||
|
||||
class LLMRole:
|
||||
AGENT = "agent" # For agent/chat operations
|
||||
|
||||
|
|
@ -73,26 +100,16 @@ class LLMRole:
|
|||
def get_global_llm_config(llm_config_id: int) -> dict | None:
|
||||
"""
|
||||
Get a global LLM configuration by ID.
|
||||
Global configs have negative IDs. ID 0 is reserved for Auto mode.
|
||||
Global configs have negative IDs. Auto mode (ID 0) is resolved through the
|
||||
model-candidate pipeline, not this legacy config lookup.
|
||||
|
||||
Args:
|
||||
llm_config_id: The ID of the global config (should be negative or 0 for Auto)
|
||||
llm_config_id: The ID of the global config (must be negative)
|
||||
|
||||
Returns:
|
||||
dict: Global config dictionary or None if not found
|
||||
"""
|
||||
# Auto mode (ID 0) is handled separately via the router
|
||||
if llm_config_id == AUTO_MODE_ID:
|
||||
return {
|
||||
"id": AUTO_MODE_ID,
|
||||
"name": "Auto (Fastest)",
|
||||
"description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
|
||||
if llm_config_id > 0:
|
||||
if llm_config_id >= 0:
|
||||
return None
|
||||
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
|
|
@ -102,6 +119,55 @@ def get_global_llm_config(llm_config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
|
||||
|
||||
def get_global_connection(connection_id: int) -> dict | None:
|
||||
return next(
|
||||
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def _has_capability(model: dict | Model, capability: str) -> bool:
|
||||
return has_capability(model, capability)
|
||||
|
||||
|
||||
def _chat_litellm_from_resolved(
|
||||
*,
|
||||
conn: dict | object,
|
||||
model_id: str,
|
||||
disable_streaming: bool = False,
|
||||
) -> tuple[str, dict]:
|
||||
model_string, resolved_kwargs = to_litellm(conn, model_id)
|
||||
litellm_kwargs = {"model": model_string, **resolved_kwargs}
|
||||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
return model_string, litellm_kwargs
|
||||
|
||||
|
||||
async def _get_db_model(
|
||||
session: AsyncSession,
|
||||
model_id: int,
|
||||
search_space: SearchSpace,
|
||||
) -> Model | None:
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == model_id, Model.enabled.is_(True))
|
||||
)
|
||||
model = result.scalars().first()
|
||||
if not model or not model.connection or not model.connection.enabled:
|
||||
return None
|
||||
conn = model.connection
|
||||
if conn.search_space_id and conn.search_space_id != search_space.id:
|
||||
return None
|
||||
if conn.user_id and conn.user_id != search_space.user_id:
|
||||
return None
|
||||
return model
|
||||
|
||||
|
||||
async def validate_llm_config(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
|
|
@ -146,62 +212,15 @@ async def validate_llm_config(
|
|||
return False, msg
|
||||
|
||||
try:
|
||||
# Build the model string for litellm
|
||||
if custom_provider:
|
||||
model_string = f"{custom_provider}/{model_name}"
|
||||
else:
|
||||
# Map provider enum to litellm format
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility)
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
# Chinese LLM providers
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai", # GLM needs special handling
|
||||
"MINIMAX": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
}
|
||||
provider_prefix = provider_map.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{model_name}"
|
||||
|
||||
# Create ChatLiteLLM instance
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": api_key,
|
||||
"timeout": 30, # Set a timeout for validation
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if litellm_params:
|
||||
litellm_kwargs.update(litellm_params)
|
||||
model_string, resolved_kwargs = _legacy_config_connection(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_provider=custom_provider,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30}
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -283,9 +302,9 @@ async def get_search_space_llm_instance(
|
|||
logger.error(f"Search space {search_space_id} not found")
|
||||
return None
|
||||
|
||||
# Get the appropriate LLM config ID based on role
|
||||
# Get the appropriate model binding ID based on role
|
||||
if role == LLMRole.AGENT:
|
||||
llm_config_id = search_space.agent_llm_id
|
||||
llm_config_id = search_space.chat_model_id
|
||||
else:
|
||||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
|
@ -294,88 +313,42 @@ async def get_search_space_llm_instance(
|
|||
logger.error(f"No {role} LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
# Check for Auto mode (ID 0) - use router for load balancing
|
||||
# Auto mode resolves to one concrete global or BYOK model from the
|
||||
# unified model-connections catalog.
|
||||
if is_auto_mode(llm_config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
"Auto mode requested but LLM Router not initialized. "
|
||||
"Ensure global_llm_config.yaml exists with valid configs."
|
||||
)
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=search_space.user_id,
|
||||
capability="chat",
|
||||
)
|
||||
if not candidates:
|
||||
logger.error("No chat-capable models available for Auto mode")
|
||||
return None
|
||||
llm_config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
|
||||
)
|
||||
return get_auto_mode_llm(streaming=not disable_streaming)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
# Check if this is a global config (negative ID)
|
||||
# Check if this is a global virtual model (negative ID)
|
||||
if llm_config_id < 0:
|
||||
global_config = get_global_llm_config(llm_config_id)
|
||||
if not global_config:
|
||||
logger.error(f"Global LLM config {llm_config_id} not found")
|
||||
global_model = get_global_model(llm_config_id)
|
||||
if not global_model or not _has_capability(global_model, "chat"):
|
||||
logger.error(f"Global chat model {llm_config_id} not found")
|
||||
return None
|
||||
global_connection = get_global_connection(global_model["connection_id"])
|
||||
if not global_connection:
|
||||
logger.error(
|
||||
"Global connection %s not found for model %s",
|
||||
global_model["connection_id"],
|
||||
llm_config_id,
|
||||
)
|
||||
return None
|
||||
|
||||
# Build model string for global config
|
||||
if global_config.get("custom_provider"):
|
||||
model_string = (
|
||||
f"{global_config['custom_provider']}/{global_config['model_name']}"
|
||||
)
|
||||
else:
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"MINIMAX": "openai",
|
||||
}
|
||||
provider_prefix = provider_map.get(
|
||||
global_config["provider"], global_config["provider"].lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{global_config['model_name']}"
|
||||
|
||||
# Create ChatLiteLLM instance from global config
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_config["api_key"],
|
||||
}
|
||||
|
||||
if global_config.get("api_base"):
|
||||
litellm_kwargs["api_base"] = global_config["api_base"]
|
||||
|
||||
if global_config.get("litellm_params"):
|
||||
litellm_kwargs.update(global_config["litellm_params"])
|
||||
|
||||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=global_connection,
|
||||
model_id=global_model["model_id"],
|
||||
disable_streaming=disable_streaming,
|
||||
)
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -383,80 +356,18 @@ async def get_search_space_llm_instance(
|
|||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (NewLLMConfig)
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).where(
|
||||
NewLLMConfig.id == llm_config_id,
|
||||
NewLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
|
||||
if not llm_config:
|
||||
model = await _get_db_model(session, llm_config_id, search_space)
|
||||
if not model or not _has_capability(model, "chat"):
|
||||
logger.error(
|
||||
f"LLM config {llm_config_id} not found in search space {search_space_id}"
|
||||
f"Chat model {llm_config_id} not found in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Build the model string for litellm
|
||||
if llm_config.custom_provider:
|
||||
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
|
||||
else:
|
||||
# Map provider enum to litellm format
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"MINIMAX": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
}
|
||||
provider_prefix = provider_map.get(
|
||||
llm_config.provider.value, llm_config.provider.value.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{llm_config.model_name}"
|
||||
|
||||
# Create ChatLiteLLM instance
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": llm_config.api_key,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if llm_config.api_base:
|
||||
litellm_kwargs["api_base"] = llm_config.api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if llm_config.litellm_params:
|
||||
litellm_kwargs.update(llm_config.litellm_params)
|
||||
|
||||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=model.connection,
|
||||
model_id=model.model_id,
|
||||
disable_streaming=disable_streaming,
|
||||
)
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -474,7 +385,7 @@ async def get_search_space_llm_instance(
|
|||
async def get_agent_llm(
|
||||
session: AsyncSession, search_space_id: int, disable_streaming: bool = False
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's agent LLM instance for chat operations."""
|
||||
"""Get the search space's chat model instance."""
|
||||
return await get_search_space_llm_instance(
|
||||
session,
|
||||
search_space_id,
|
||||
|
|
@ -488,24 +399,17 @@ async def get_vision_llm(
|
|||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's vision LLM instance for screenshot analysis.
|
||||
|
||||
Resolves from the dedicated VisionLLMConfig system:
|
||||
- Auto mode (ID 0): VisionLLMRouterService
|
||||
- Global (negative ID): YAML configs
|
||||
- DB (positive ID): VisionLLMConfig table
|
||||
Resolves from the new connection/model role bindings:
|
||||
- Auto mode (ID 0): unified global/BYOK model candidate selection
|
||||
- Global (negative ID): virtual GLOBAL models from YAML
|
||||
- DB (positive ID): Model + Connection tables
|
||||
|
||||
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
|
||||
so each ``ainvoke`` debits the search-space owner's premium credit
|
||||
pool. User-owned BYOK configs and free global configs are returned
|
||||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.db import VisionLLMConfig
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
from app.services.vision_llm_router_service import (
|
||||
VISION_PROVIDER_MAP,
|
||||
VisionLLMRouterService,
|
||||
get_global_vision_llm_config,
|
||||
is_vision_auto_mode,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -516,64 +420,78 @@ async def get_vision_llm(
|
|||
logger.error(f"Search space {search_space_id} not found")
|
||||
return None
|
||||
|
||||
config_id = search_space.vision_llm_config_id
|
||||
owner_user_id = search_space.user_id
|
||||
|
||||
# Prefer the selected chat model when it is vision-capable.
|
||||
chat_model_id = search_space.chat_model_id
|
||||
if chat_model_id and chat_model_id != AUTO_MODE_ID:
|
||||
if chat_model_id < 0:
|
||||
chat_model = get_global_model(chat_model_id)
|
||||
if chat_model and _has_capability(chat_model, "vision"):
|
||||
global_connection = get_global_connection(
|
||||
chat_model["connection_id"]
|
||||
)
|
||||
if global_connection:
|
||||
model_string, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=global_connection,
|
||||
model_id=chat_model["model_id"],
|
||||
)
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
)
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
else:
|
||||
chat_model = await _get_db_model(session, chat_model_id, search_space)
|
||||
if chat_model and _has_capability(chat_model, "vision"):
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=chat_model.connection,
|
||||
model_id=chat_model.model_id,
|
||||
)
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
)
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
config_id = search_space.vision_model_id
|
||||
if config_id is None:
|
||||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
owner_user_id = search_space.user_id
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if not VisionLLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
"Vision Auto mode requested but Vision LLM Router not initialized"
|
||||
)
|
||||
return None
|
||||
try:
|
||||
# Auto mode is currently treated as free at the wrapper
|
||||
# level — the underlying router can dispatch to either
|
||||
# premium or free YAML configs but routing decisions are
|
||||
# opaque. If/when we want to bill Auto-routed vision
|
||||
# calls we'd need to thread the resolved deployment's
|
||||
# billing_tier back from the router. For now we keep
|
||||
# parity with chat Auto, which also doesn't pre-classify.
|
||||
return ChatLiteLLMRouter(
|
||||
router=VisionLLMRouterService.get_router(),
|
||||
streaming=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}")
|
||||
if config_id == AUTO_MODE_ID:
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=owner_user_id,
|
||||
capability="vision",
|
||||
)
|
||||
if not candidates:
|
||||
logger.error("No vision-capable models available for Auto mode")
|
||||
return None
|
||||
config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
)
|
||||
|
||||
if config_id < 0:
|
||||
global_cfg = get_global_vision_llm_config(config_id)
|
||||
if not global_cfg:
|
||||
logger.error(f"Global vision LLM config {config_id} not found")
|
||||
global_model = get_global_model(config_id)
|
||||
if not global_model or not _has_capability(global_model, "vision"):
|
||||
logger.error(f"Global vision model {config_id} not found")
|
||||
return None
|
||||
|
||||
if global_cfg.get("custom_provider"):
|
||||
provider_prefix = global_cfg["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
else:
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
global_cfg["provider"].upper(),
|
||||
global_cfg["provider"].lower(),
|
||||
global_connection = get_global_connection(global_model["connection_id"])
|
||||
if not global_connection:
|
||||
logger.error(
|
||||
"Global connection %s not found for model %s",
|
||||
global_model["connection_id"],
|
||||
config_id,
|
||||
)
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
return None
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_cfg["api_key"],
|
||||
}
|
||||
api_base = resolve_api_base(
|
||||
provider=global_cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=global_cfg.get("api_base"),
|
||||
model_string, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=global_connection,
|
||||
model_id=global_model["model_id"],
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -581,7 +499,7 @@ async def get_vision_llm(
|
|||
|
||||
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
|
||||
billing_tier = str(global_model.get("billing_tier", "free")).lower()
|
||||
if billing_tier == "premium":
|
||||
return QuotaCheckedVisionLLM(
|
||||
inner_llm,
|
||||
|
|
@ -589,47 +507,23 @@ async def get_vision_llm(
|
|||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=model_string,
|
||||
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
|
||||
quota_reserve_tokens=global_model.get("catalog", {}).get(
|
||||
"quota_reserve_tokens"
|
||||
),
|
||||
)
|
||||
return inner_llm
|
||||
|
||||
# User-owned (positive ID) BYOK configs — always free.
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
VisionLLMConfig.id == config_id,
|
||||
VisionLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
vision_cfg = result.scalars().first()
|
||||
if not vision_cfg:
|
||||
model = await _get_db_model(session, config_id, search_space)
|
||||
if not model or not _has_capability(model, "vision"):
|
||||
logger.error(
|
||||
f"Vision LLM config {config_id} not found in search space {search_space_id}"
|
||||
f"Vision model {config_id} not found in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
if vision_cfg.custom_provider:
|
||||
provider_prefix = vision_cfg.custom_provider
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
else:
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
vision_cfg.provider.value.upper(),
|
||||
vision_cfg.provider.value.lower(),
|
||||
)
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": vision_cfg.api_key,
|
||||
}
|
||||
api_base = resolve_api_base(
|
||||
provider=vision_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=vision_cfg.api_base,
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=model.connection,
|
||||
model_id=model.model_id,
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
|
|||
36
surfsense_backend/app/services/model_capabilities.py
Normal file
36
surfsense_backend/app/services/model_capabilities.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Override-aware model capability lookup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
CAPABILITY_FIELDS = {
|
||||
"chat": "supports_chat",
|
||||
"vision": "supports_image_input",
|
||||
"image_gen": "supports_image_generation",
|
||||
"tools": "supports_tools",
|
||||
}
|
||||
|
||||
|
||||
def _get_value(model: Any, key: str) -> Any:
|
||||
if isinstance(model, Mapping):
|
||||
return model.get(key)
|
||||
return getattr(model, key, None)
|
||||
|
||||
|
||||
def has_capability(model: Any, capability: str) -> bool:
|
||||
field = CAPABILITY_FIELDS.get(capability)
|
||||
if field is None:
|
||||
return False
|
||||
|
||||
override = _get_value(model, "capabilities_override") or {}
|
||||
if isinstance(override, Mapping) and field in override:
|
||||
return bool(override[field])
|
||||
if isinstance(override, Mapping) and capability in override:
|
||||
return bool(override[capability])
|
||||
|
||||
return bool(_get_value(model, field))
|
||||
|
||||
|
||||
__all__ = ["CAPABILITY_FIELDS", "has_capability"]
|
||||
490
surfsense_backend/app/services/model_connection_service.py
Normal file
490
surfsense_backend/app/services/model_connection_service.py
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
"""Connection verification, model discovery, and capability probing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
from app.db import Connection, Model, ModelSource
|
||||
from app.services.model_resolver import ensure_v1, to_litellm
|
||||
from app.services.openrouter_model_normalizer import normalize_openrouter_models
|
||||
from app.services.provider_registry import Transport, provider_label, spec_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VERIFY_TIMEOUT_SECONDS = 8.0
|
||||
DISCOVERY_TIMEOUT_SECONDS = 15.0
|
||||
TEST_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VerifyResult:
|
||||
status: str
|
||||
ok: bool
|
||||
message: str = ""
|
||||
|
||||
|
||||
class ModelDiscoveryError(Exception):
|
||||
"""User-correctable discovery failure for provider configuration issues."""
|
||||
|
||||
|
||||
def _auth_headers(conn: Connection) -> dict[str, str]:
|
||||
if not conn.api_key:
|
||||
return {}
|
||||
return {"Authorization": f"Bearer {conn.api_key}"}
|
||||
|
||||
|
||||
def _anthropic_headers(conn: Connection) -> dict[str, str]:
|
||||
headers = {"anthropic-version": "2023-06-01"}
|
||||
if conn.api_key:
|
||||
headers["x-api-key"] = conn.api_key
|
||||
return headers
|
||||
|
||||
|
||||
def _base_url_or_default(conn: Connection) -> str | None:
|
||||
if conn.base_url:
|
||||
return conn.base_url.rstrip("/")
|
||||
if conn.provider == "openai":
|
||||
return "https://api.openai.com/v1"
|
||||
if conn.provider == "anthropic":
|
||||
return "https://api.anthropic.com/v1"
|
||||
return spec_for(conn.provider).default_base_url
|
||||
|
||||
|
||||
def _docker_hint(url: str | None, exc_or_status: Any) -> str:
|
||||
raw = str(exc_or_status)
|
||||
if not url:
|
||||
return raw
|
||||
if "localhost" in url or "127.0.0.1" in url:
|
||||
return (
|
||||
f"{raw}. The backend is running inside Docker; localhost means the "
|
||||
"backend container. Use host.docker.internal and make sure the model "
|
||||
"server listens on 0.0.0.0."
|
||||
)
|
||||
if "host.docker.internal" in url and (
|
||||
"refused" in raw.lower() or "connect" in raw.lower()
|
||||
):
|
||||
return (
|
||||
f"{raw}. The host is reachable only if your local model server is "
|
||||
"listening on 0.0.0.0. On Linux Docker, add "
|
||||
"`host.docker.internal:host-gateway` to extra_hosts."
|
||||
)
|
||||
return raw
|
||||
|
||||
|
||||
def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult:
|
||||
provider_name = provider_label(conn.provider)
|
||||
raw = str(exc)
|
||||
normalized = raw.lower()
|
||||
exc_name = exc.__class__.__name__.lower()
|
||||
status_code = getattr(exc, "status_code", None)
|
||||
|
||||
logger.info(
|
||||
"Model test failed for provider=%s model=%s: %s",
|
||||
conn.provider,
|
||||
model_id,
|
||||
raw,
|
||||
)
|
||||
|
||||
if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized:
|
||||
return VerifyResult(
|
||||
"AUTH_FAILED",
|
||||
False,
|
||||
f"Authentication failed. Check your {provider_name} credentials and try again.",
|
||||
)
|
||||
|
||||
if status_code == 404 or "notfound" in exc_name or "not found" in normalized:
|
||||
if conn.provider == "azure":
|
||||
message = (
|
||||
"Azure OpenAI deployment was not found. Check the deployment name, "
|
||||
"API version, and endpoint."
|
||||
)
|
||||
else:
|
||||
message = f"Model '{model_id}' was not found on {provider_name}."
|
||||
return VerifyResult("NOT_FOUND", False, message)
|
||||
|
||||
if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized:
|
||||
return VerifyResult(
|
||||
"RATE_LIMITED",
|
||||
False,
|
||||
f"{provider_name} rate limited the model test. Try again later.",
|
||||
)
|
||||
|
||||
if "timeout" in exc_name or "timed out" in normalized:
|
||||
return VerifyResult(
|
||||
"TIMEOUT",
|
||||
False,
|
||||
f"{provider_name} did not respond in time. Check the endpoint and try again.",
|
||||
)
|
||||
|
||||
if "connection" in exc_name or "connect" in normalized:
|
||||
return VerifyResult(
|
||||
"UNREACHABLE",
|
||||
False,
|
||||
_docker_hint(
|
||||
_base_url_or_default(conn),
|
||||
f"Could not reach {provider_name}. Check the endpoint and try again.",
|
||||
),
|
||||
)
|
||||
|
||||
return VerifyResult(
|
||||
"UNREACHABLE",
|
||||
False,
|
||||
f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.",
|
||||
)
|
||||
|
||||
|
||||
async def verify_connection(conn: Connection) -> VerifyResult:
|
||||
spec = spec_for(conn.provider)
|
||||
base_url = _base_url_or_default(conn)
|
||||
if spec.base_url_required and not base_url:
|
||||
return VerifyResult("UNREACHABLE", False, "Base URL is required.")
|
||||
|
||||
if spec.transport == Transport.OLLAMA and base_url:
|
||||
url = f"{base_url.rstrip('/')}/api/version"
|
||||
elif spec.discovery in {"openai_models", "openrouter"} and base_url:
|
||||
url = f"{ensure_v1(base_url)}/models"
|
||||
elif spec.discovery == "anthropic_models" and base_url:
|
||||
url = f"{base_url.rstrip('/')}/models"
|
||||
else:
|
||||
return VerifyResult(
|
||||
"OK", True, "Connection uses provider-native authentication."
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
|
||||
headers = (
|
||||
_anthropic_headers(conn)
|
||||
if spec.auth_style == "x-api-key"
|
||||
else _auth_headers(conn)
|
||||
)
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
|
||||
if response.status_code == 404:
|
||||
if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"):
|
||||
message = "Ollama native API should not use /v1."
|
||||
elif spec.transport == Transport.OPENAI_COMPATIBLE:
|
||||
message = "OpenAI-compatible servers should expose /v1/models."
|
||||
else:
|
||||
message = "Endpoint returned 404."
|
||||
return VerifyResult("NOT_FOUND", False, message)
|
||||
response.raise_for_status()
|
||||
return VerifyResult("OK", True, "Connection verified.")
|
||||
except httpx.ConnectError as exc:
|
||||
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
|
||||
except httpx.TimeoutException as exc:
|
||||
return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}")
|
||||
except httpx.HTTPError as exc:
|
||||
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
|
||||
|
||||
|
||||
def _discovery_error_message(conn: Connection, exc: httpx.HTTPError) -> str:
|
||||
base_url = _base_url_or_default(conn)
|
||||
if isinstance(exc, httpx.HTTPStatusError):
|
||||
status_code = exc.response.status_code
|
||||
if status_code in (401, 403):
|
||||
return "Authentication failed while discovering models."
|
||||
if status_code == 404:
|
||||
spec = spec_for(conn.provider)
|
||||
if spec.transport == Transport.OPENAI_COMPATIBLE:
|
||||
return "OpenAI-compatible servers should expose /v1/models."
|
||||
return "Model discovery endpoint returned 404."
|
||||
return f"Model discovery failed with HTTP {status_code}."
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
return f"Model discovery timed out: {exc}"
|
||||
return _docker_hint(base_url, exc)
|
||||
|
||||
|
||||
def _allowlist(conn: Connection) -> set[str]:
|
||||
raw = (conn.extra or {}).get("model_ids") or []
|
||||
return {str(item).strip() for item in raw if str(item).strip()}
|
||||
|
||||
|
||||
def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]:
|
||||
with contextlib.suppress(Exception):
|
||||
info = litellm.get_model_info(model=model_string)
|
||||
if isinstance(info, dict):
|
||||
return info
|
||||
return (
|
||||
litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
|
||||
)
|
||||
|
||||
|
||||
def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]:
|
||||
info = _litellm_info(model_string, model_id)
|
||||
mode = info.get("mode")
|
||||
supports_image_input = False
|
||||
supports_tools = False
|
||||
with contextlib.suppress(Exception):
|
||||
supports_image_input = bool(litellm.supports_vision(model=model_string))
|
||||
with contextlib.suppress(Exception):
|
||||
supports_tools = bool(litellm.supports_function_calling(model=model_string))
|
||||
return {
|
||||
"supports_chat": mode in (None, "chat", "completion", "responses"),
|
||||
"max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
"supports_tools": supports_tools,
|
||||
"supports_image_generation": mode
|
||||
in {"image_generation", "image_generation_model"},
|
||||
}
|
||||
|
||||
|
||||
def derive_capabilities(
|
||||
conn: Connection, model_id: str, metadata: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
metadata = metadata or {}
|
||||
spec = spec_for(conn.provider)
|
||||
model_string, _ = to_litellm(conn, model_id)
|
||||
facts = _classify_from_litellm(model_string, model_id)
|
||||
if spec.transport == Transport.OLLAMA:
|
||||
caps = set(metadata.get("capabilities") or [])
|
||||
details = metadata.get("details") or {}
|
||||
facts.update(
|
||||
{
|
||||
"supports_chat": "embedding" not in caps,
|
||||
"supports_image_input": "vision" in caps
|
||||
or facts["supports_image_input"],
|
||||
"supports_tools": "tools" in caps or facts["supports_tools"],
|
||||
"supports_image_generation": False,
|
||||
"max_input_tokens": metadata.get("context_length")
|
||||
or metadata.get("num_ctx")
|
||||
or details.get("context_length")
|
||||
or facts["max_input_tokens"],
|
||||
}
|
||||
)
|
||||
return facts
|
||||
|
||||
|
||||
async def _discover_openai_shaped_models(
|
||||
conn: Connection, base_url: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
resolved_base_url = base_url or _base_url_or_default(conn)
|
||||
if not resolved_base_url:
|
||||
return []
|
||||
|
||||
url = f"{ensure_v1(resolved_base_url)}/models"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(url, headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in response.json().get("data", []):
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": item.get("name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**derive_capabilities(conn, model_id, item),
|
||||
"metadata": item,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
base_url = _base_url_or_default(conn)
|
||||
if not base_url:
|
||||
return []
|
||||
|
||||
url = f"{base_url.rstrip('/')}/models"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(url, headers=_anthropic_headers(conn))
|
||||
response.raise_for_status()
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in response.json().get("data", []):
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": item.get("display_name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**derive_capabilities(conn, model_id, item),
|
||||
"metadata": item,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]:
|
||||
if not conn.base_url:
|
||||
return []
|
||||
|
||||
base_url = conn.base_url.rstrip("/")
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
models = response.json().get("models", [])
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in models:
|
||||
model_id = item.get("model") or item.get("name")
|
||||
if not model_id:
|
||||
continue
|
||||
metadata = dict(item)
|
||||
with contextlib.suppress(Exception):
|
||||
show_response = await client.post(
|
||||
f"{base_url}/api/show",
|
||||
json={"model": model_id},
|
||||
headers=_auth_headers(conn),
|
||||
)
|
||||
show_response.raise_for_status()
|
||||
metadata.update(show_response.json())
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": item.get("name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**derive_capabilities(conn, model_id, metadata),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(
|
||||
f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn)
|
||||
)
|
||||
response.raise_for_status()
|
||||
return normalize_openrouter_models(response.json().get("data", []))
|
||||
|
||||
|
||||
def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
provider = conn.provider
|
||||
prefix = spec_for(provider).litellm_prefix or provider
|
||||
results: list[dict[str, Any]] = []
|
||||
for model_string, metadata in litellm.model_cost.items():
|
||||
if not isinstance(model_string, str) or not model_string.startswith(
|
||||
f"{prefix}/"
|
||||
):
|
||||
continue
|
||||
model_id = model_string.split("/", 1)[1]
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": metadata.get("display_name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
**_classify_from_litellm(model_string, model_id),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
params = (conn.extra or {}).get("litellm_params", {})
|
||||
region_name = params.get("aws_region_name")
|
||||
if not region_name:
|
||||
return []
|
||||
|
||||
def list_models() -> list[dict[str, Any]]:
|
||||
import os
|
||||
|
||||
import boto3
|
||||
|
||||
if bearer_token := params.get("aws_bearer_token_bedrock"):
|
||||
try:
|
||||
os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token
|
||||
client = boto3.client("bedrock", region_name=region_name)
|
||||
finally:
|
||||
os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None)
|
||||
else:
|
||||
client_kwargs: dict[str, str] = {"region_name": region_name}
|
||||
if params.get("aws_access_key_id"):
|
||||
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
|
||||
if params.get("aws_secret_access_key"):
|
||||
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
|
||||
client = boto3.client("bedrock", **client_kwargs)
|
||||
|
||||
response = client.list_foundation_models()
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in response.get("modelSummaries", []):
|
||||
model_id = item.get("modelId")
|
||||
if not model_id:
|
||||
continue
|
||||
input_modalities = set(item.get("inputModalities") or [])
|
||||
output_modalities = set(item.get("outputModalities") or [])
|
||||
results.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": item.get("modelName") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"supports_chat": "TEXT" in input_modalities
|
||||
and "TEXT" in output_modalities,
|
||||
"supports_image_input": "IMAGE" in input_modalities,
|
||||
"supports_tools": None,
|
||||
"supports_image_generation": "IMAGE" in output_modalities,
|
||||
"max_input_tokens": None,
|
||||
"metadata": item,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
return await anyio.to_thread.run_sync(list_models)
|
||||
|
||||
|
||||
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
allowlist = _allowlist(conn)
|
||||
spec = spec_for(conn.provider)
|
||||
|
||||
try:
|
||||
if spec.discovery == "ollama":
|
||||
results = await _ollama_tags_then_show(conn)
|
||||
elif spec.discovery == "openrouter":
|
||||
results = await _openrouter_models(conn)
|
||||
elif spec.discovery == "anthropic_models":
|
||||
results = await _discover_anthropic_models(conn)
|
||||
elif spec.discovery == "openai_models":
|
||||
results = await _discover_openai_shaped_models(conn, conn.base_url)
|
||||
elif spec.discovery == "bedrock_models":
|
||||
results = await _discover_bedrock_models(conn)
|
||||
elif spec.discovery == "static":
|
||||
results = _litellm_static_models(conn)
|
||||
else:
|
||||
results = []
|
||||
except httpx.HTTPError as exc:
|
||||
raise ModelDiscoveryError(_discovery_error_message(conn, exc)) from exc
|
||||
|
||||
if allowlist:
|
||||
results = [item for item in results if item["model_id"] in allowlist]
|
||||
return results
|
||||
|
||||
|
||||
async def test_model(conn: Connection, model: Model) -> VerifyResult:
|
||||
model_string, kwargs = to_litellm(conn, model.model_id)
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model=model_string,
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
timeout=TEST_TIMEOUT_SECONDS,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as exc:
|
||||
return _model_test_error(conn, model.model_id, exc)
|
||||
|
||||
model.supports_chat = True
|
||||
return VerifyResult("OK", True, "Model test succeeded.")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModelDiscoveryError",
|
||||
"VerifyResult",
|
||||
"derive_capabilities",
|
||||
"discover_models",
|
||||
"test_model",
|
||||
"verify_connection",
|
||||
]
|
||||
|
|
@ -12,6 +12,8 @@ from pathlib import Path
|
|||
|
||||
import httpx
|
||||
|
||||
from app.services.openrouter_model_normalizer import normalize_openrouter_models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
|
|
@ -22,7 +24,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours
|
|||
_cache: list[dict] | None = None
|
||||
_cache_timestamp: float = 0
|
||||
|
||||
# Maps OpenRouter provider slug → our LiteLLMProvider enum value.
|
||||
# Maps OpenRouter provider slug to native LiteLLM provider prefixes.
|
||||
# Only providers where the model-name part (after the slash) can be
|
||||
# used directly with the native provider's litellm prefix are listed.
|
||||
#
|
||||
|
|
@ -121,26 +123,13 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
|
|||
"""
|
||||
processed: list[dict] = []
|
||||
|
||||
for model in raw_models:
|
||||
model_id: str = model.get("id", "")
|
||||
name: str = model.get("name", "")
|
||||
context_length = model.get("context_length")
|
||||
|
||||
for normalized in normalize_openrouter_models(raw_models):
|
||||
model_id: str = normalized["model_id"]
|
||||
name: str = normalized.get("display_name") or model_id
|
||||
context_length = normalized.get("max_input_tokens")
|
||||
if "/" not in model_id:
|
||||
continue
|
||||
|
||||
if not _is_text_output_model(model):
|
||||
continue
|
||||
|
||||
if not _supports_tool_calling(model):
|
||||
continue
|
||||
|
||||
if not _has_sufficient_context(model):
|
||||
continue
|
||||
|
||||
if not _is_allowed_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
|
|
@ -154,19 +143,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
|
|||
}
|
||||
)
|
||||
|
||||
# 2) Emit for the native provider when we have a mapping
|
||||
native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
|
||||
if native_provider:
|
||||
# 2) Emit for the direct provider when we have a mapping
|
||||
direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
|
||||
if direct_provider:
|
||||
# Google's Gemini API only serves gemini-* models.
|
||||
# Open-source models like gemma-* are NOT available through it.
|
||||
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
continue
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_name,
|
||||
"label": name,
|
||||
"provider": native_provider,
|
||||
"provider": direct_provider,
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
94
surfsense_backend/app/services/model_resolver.py
Normal file
94
surfsense_backend/app/services/model_resolver.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
"""Single model-to-LiteLLM resolver.
|
||||
|
||||
All chat, vision, image-generation, validation, and Auto routing paths should
|
||||
turn a Connection + Model into LiteLLM input through this module.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db import Connection
|
||||
|
||||
from app.services.provider_registry import Transport, spec_for
|
||||
|
||||
|
||||
def ensure_v1(base_url: str | None) -> str | None:
|
||||
if not base_url:
|
||||
return None
|
||||
stripped = base_url.rstrip("/")
|
||||
if stripped.endswith("/v1"):
|
||||
return stripped
|
||||
return f"{stripped}/v1"
|
||||
|
||||
|
||||
def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any:
|
||||
if isinstance(conn, Mapping):
|
||||
return conn.get(key)
|
||||
return getattr(conn, key)
|
||||
|
||||
|
||||
def to_litellm(
|
||||
conn: Connection | Mapping[str, Any],
|
||||
model_id: str,
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
"""Return ``(model_string, litellm_kwargs)`` for any model role."""
|
||||
provider = _conn_value(conn, "provider")
|
||||
base_url = _conn_value(conn, "base_url")
|
||||
api_key = _conn_value(conn, "api_key")
|
||||
extra = _conn_value(conn, "extra") or {}
|
||||
spec = spec_for(provider)
|
||||
|
||||
kwargs: dict[str, Any] = {}
|
||||
if api_key:
|
||||
kwargs["api_key"] = api_key
|
||||
|
||||
prefix = spec.litellm_prefix or str(provider)
|
||||
model_string = f"{prefix}/{model_id}" if prefix else model_id
|
||||
if base_url:
|
||||
api_base = (
|
||||
ensure_v1(base_url)
|
||||
if spec.transport == Transport.OPENAI_COMPATIBLE
|
||||
else base_url.rstrip("/")
|
||||
)
|
||||
kwargs["api_base"] = api_base
|
||||
|
||||
if api_version := extra.get("api_version"):
|
||||
kwargs["api_version"] = api_version
|
||||
kwargs.update(extra.get("litellm_params", {}))
|
||||
kwargs.update(extra.get("kwargs", {}))
|
||||
if provider == "bedrock" and (
|
||||
bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)
|
||||
):
|
||||
kwargs["api_key"] = bearer_token
|
||||
return model_string, kwargs
|
||||
|
||||
|
||||
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Build an in-memory connection mapping from a global config."""
|
||||
provider = str(
|
||||
config.get("provider")
|
||||
or config.get("litellm_provider")
|
||||
or config.get("custom_provider")
|
||||
or "openai"
|
||||
)
|
||||
extra: dict[str, Any] = {
|
||||
"litellm_params": config.get("litellm_params") or {},
|
||||
}
|
||||
if config.get("api_version"):
|
||||
extra["api_version"] = config.get("api_version")
|
||||
return {
|
||||
"provider": provider,
|
||||
"base_url": config.get("api_base") or None,
|
||||
"api_key": config.get("api_key") or None,
|
||||
"extra": extra,
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ensure_v1",
|
||||
"native_connection_from_config",
|
||||
"to_litellm",
|
||||
]
|
||||
|
|
@ -19,6 +19,10 @@ from typing import Any
|
|||
|
||||
import httpx
|
||||
|
||||
from app.services.openrouter_model_normalizer import (
|
||||
is_openrouter_image_model,
|
||||
normalize_openrouter_models,
|
||||
)
|
||||
from app.services.quality_score import (
|
||||
_HEALTH_BLEND_WEIGHT,
|
||||
_HEALTH_ENRICH_CONCURRENCY,
|
||||
|
|
@ -274,7 +278,7 @@ def _generate_configs(
|
|||
|
||||
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
|
||||
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
|
||||
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
|
||||
because our own Auto pin + 24 h refresh + repair logic already
|
||||
cover the catalogue-churn case.
|
||||
"""
|
||||
id_offset: int = settings.get("id_offset", -10000)
|
||||
|
|
@ -292,24 +296,16 @@ def _generate_configs(
|
|||
use_default: bool = settings.get("use_default_system_instructions", True)
|
||||
citations_enabled: bool = settings.get("citations_enabled", True)
|
||||
|
||||
text_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_text_output_model(m)
|
||||
and _supports_tool_calling(m)
|
||||
and _has_sufficient_context(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
text_models = normalize_openrouter_models(raw_models)
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
now_ts = int(time.time())
|
||||
|
||||
for model in text_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
for normalized in text_models:
|
||||
model = normalized.get("metadata") or {}
|
||||
model_id: str = normalized["model_id"]
|
||||
name: str = normalized.get("display_name") or model_id
|
||||
tier = _openrouter_tier(model)
|
||||
|
||||
static_q = static_score_or(model, now_ts=now_ts)
|
||||
|
|
@ -323,10 +319,10 @@ def _generate_configs(
|
|||
"seo_enabled": seo_enabled,
|
||||
"seo_slug": None,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
"provider": "OPENROUTER",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"tpm": free_tpm if tier == "free" else tpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
|
|
@ -345,9 +341,9 @@ def _generate_configs(
|
|||
# ``stream_new_chat`` as a fail-fast safety net before the
|
||||
# OpenRouter request would otherwise 404 with
|
||||
# ``"No endpoints found that support image input"``.
|
||||
"supports_image_input": _supports_image_input(model),
|
||||
"supports_image_input": bool(normalized.get("supports_image_input")),
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||
# Auto ranking metadata. ``quality_score`` is initialised
|
||||
# to the static score and gets re-blended with health on the next
|
||||
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
|
||||
# start so startup latency is unchanged).
|
||||
|
|
@ -362,11 +358,7 @@ def _generate_configs(
|
|||
return configs
|
||||
|
||||
|
||||
# ID-offset bands used to keep dynamic OpenRouter configs in their own
|
||||
# namespace per surface. Image / vision get separate bands so a single
|
||||
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
|
||||
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
|
||||
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
|
||||
|
||||
|
||||
def _generate_image_gen_configs(
|
||||
|
|
@ -400,14 +392,7 @@ def _generate_image_gen_configs(
|
|||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
|
||||
image_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_image_output_model(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
image_models = [m for m in raw_models if is_openrouter_image_model(m)]
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
|
|
@ -420,14 +405,9 @@ def _generate_image_gen_configs(
|
|||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter (image generation)",
|
||||
"provider": "OPENROUTER",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
# Pin to OpenRouter's public base URL so a downstream call site
|
||||
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
|
||||
# ``image_generation/transformation`` (defense-in-depth, see
|
||||
# ``provider_api_base`` docstring).
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
|
|
@ -440,93 +420,6 @@ def _generate_image_gen_configs(
|
|||
return configs
|
||||
|
||||
|
||||
def _generate_vision_llm_configs(
|
||||
raw_models: list[dict], settings: dict[str, Any]
|
||||
) -> list[dict]:
|
||||
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
|
||||
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
|
||||
|
||||
Filter:
|
||||
- architecture.input_modalities contains "image"
|
||||
- architecture.output_modalities contains "text"
|
||||
- compatible provider (excluded slugs blocked)
|
||||
- allowed model id (excluded list blocked)
|
||||
|
||||
Vision-LLM is invoked from the indexer (image extraction during
|
||||
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
|
||||
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
|
||||
filters do not apply: a small-context vision model that doesn't
|
||||
advertise tool-calling is still perfectly viable for "describe this
|
||||
image" prompts.
|
||||
"""
|
||||
id_offset: int = int(
|
||||
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
|
||||
)
|
||||
api_key: str = settings.get("api_key", "")
|
||||
rpm: int = settings.get("rpm", 200)
|
||||
tpm: int = settings.get("tpm", 1_000_000)
|
||||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
free_tpm: int = settings.get("free_tpm", 100_000)
|
||||
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
|
||||
vision_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_vision_input_model(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
for model in vision_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
tier = _openrouter_tier(model)
|
||||
pricing = model.get("pricing") or {}
|
||||
|
||||
# Capture per-token prices so ``pricing_registration`` can
|
||||
# register them with LiteLLM at startup (and so the cost
|
||||
# estimator in ``estimate_call_reserve_micros`` can resolve
|
||||
# them at reserve time).
|
||||
try:
|
||||
input_cost = float(pricing.get("prompt", 0) or 0)
|
||||
except (TypeError, ValueError):
|
||||
input_cost = 0.0
|
||||
try:
|
||||
output_cost = float(pricing.get("completion", 0) or 0)
|
||||
except (TypeError, ValueError):
|
||||
output_cost = 0.0
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter (vision)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
# Pin to OpenRouter's public base URL so a downstream call site
|
||||
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
|
||||
# ``provider_api_base`` docstring).
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"tpm": free_tpm if tier == "free" else tpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
"billing_tier": tier,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
"input_cost_per_token": input_cost or None,
|
||||
"output_cost_per_token": output_cost or None,
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
}
|
||||
configs.append(cfg)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
class OpenRouterIntegrationService:
|
||||
"""Singleton that manages the dynamic OpenRouter model catalogue."""
|
||||
|
||||
|
|
@ -553,11 +446,9 @@ class OpenRouterIntegrationService:
|
|||
# Cached raw catalogue from the most recent fetch. Image / vision
|
||||
# emitters reuse this to avoid a second network call per surface.
|
||||
self._raw_models: list[dict] = []
|
||||
# Image / vision config caches (only populated when the matching
|
||||
# opt-in flag is true on initialize). Refreshed in lockstep with
|
||||
# the chat catalogue.
|
||||
# Image config cache (only populated when the matching opt-in flag is
|
||||
# true on initialize). Refreshed in lockstep with the chat catalogue.
|
||||
self._image_configs: list[dict] = []
|
||||
self._vision_configs: list[dict] = []
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||
|
|
@ -592,7 +483,7 @@ class OpenRouterIntegrationService:
|
|||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||
|
||||
# Populate image / vision caches when their opt-in flag is set.
|
||||
# Populate image cache when its opt-in flag is set.
|
||||
# Empty otherwise so the accessors return [] without re-running
|
||||
# filters every refresh.
|
||||
if settings.get("image_generation_enabled"):
|
||||
|
|
@ -604,15 +495,6 @@ class OpenRouterIntegrationService:
|
|||
else:
|
||||
self._image_configs = []
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: vision LLM emission ON (%d models)",
|
||||
len(self._vision_configs),
|
||||
)
|
||||
else:
|
||||
self._vision_configs = []
|
||||
|
||||
self._initialized = True
|
||||
|
||||
tier_counts = self._tier_counts(self._configs)
|
||||
|
|
@ -666,9 +548,9 @@ class OpenRouterIntegrationService:
|
|||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
# Image / vision lists are atomic-swapped the same way: filter out
|
||||
# Image list is atomic-swapped the same way: filter out
|
||||
# the previous dynamic entries from the live config list and append
|
||||
# the freshly generated ones. No-ops when the opt-in flag is off.
|
||||
# the freshly generated ones. No-op when the opt-in flag is off.
|
||||
if self._settings.get("image_generation_enabled"):
|
||||
new_image = _generate_image_gen_configs(raw_models, self._settings)
|
||||
static_image = [
|
||||
|
|
@ -679,16 +561,6 @@ class OpenRouterIntegrationService:
|
|||
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
|
||||
self._image_configs = new_image
|
||||
|
||||
if self._settings.get("vision_enabled"):
|
||||
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
|
||||
static_vision = [
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
|
||||
self._vision_configs = new_vision
|
||||
|
||||
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||
# earned by the previous turn's preflight. Drop the whole table so
|
||||
# the next turn re-probes against the freshly loaded configs.
|
||||
|
|
@ -710,7 +582,7 @@ class OpenRouterIntegrationService:
|
|||
)
|
||||
|
||||
# Re-blend health scores against the freshly fetched catalogue. Also
|
||||
# re-stamps health for any YAML-curated cfg with provider==OPENROUTER
|
||||
# re-stamps health for any YAML-curated cfg with provider=openrouter
|
||||
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||
|
||||
|
|
@ -758,7 +630,7 @@ class OpenRouterIntegrationService:
|
|||
return counts
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto (Fastest) health enrichment
|
||||
# Auto health enrichment
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _enrich_health_safely(
|
||||
|
|
@ -787,7 +659,7 @@ class OpenRouterIntegrationService:
|
|||
the entire previous cycle's cache for this run.
|
||||
"""
|
||||
or_cfgs = [
|
||||
c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
|
||||
c for c in configs if str(c.get("provider", "")).lower() == "openrouter"
|
||||
]
|
||||
if not or_cfgs:
|
||||
return
|
||||
|
|
@ -968,17 +840,6 @@ class OpenRouterIntegrationService:
|
|||
"""
|
||||
return list(self._image_configs)
|
||||
|
||||
def get_vision_llm_configs(self) -> list[dict]:
|
||||
"""Return the dynamic OpenRouter vision-LLM configs (empty list
|
||||
when the ``vision_enabled`` flag is off).
|
||||
|
||||
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
|
||||
so ``pricing_registration`` can teach LiteLLM the cost of these
|
||||
models the same way it does for chat — which keeps the billable
|
||||
wrapper able to debit accurate micro-USD on a vision call.
|
||||
"""
|
||||
return list(self._vision_configs)
|
||||
|
||||
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
|
||||
"""Return the cached raw OpenRouter pricing map.
|
||||
|
||||
|
|
|
|||
123
surfsense_backend/app/services/openrouter_model_normalizer.py
Normal file
123
surfsense_backend/app/services/openrouter_model_normalizer.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Shared OpenRouter model normalization.
|
||||
|
||||
OpenRouter metadata is richer than generic OpenAI-compatible ``/models``
|
||||
responses. Keep all OpenRouter filtering and capability extraction here so
|
||||
GLOBAL catalogue generation and BYOK discovery agree.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.db import ModelSource
|
||||
|
||||
MIN_CONTEXT_LENGTH = 100_000
|
||||
|
||||
EXCLUDED_PROVIDER_SLUGS = {"amazon"}
|
||||
EXCLUDED_MODEL_IDS: set[str] = {
|
||||
"openai/gpt-4-1106-preview",
|
||||
"openai/gpt-4-turbo-preview",
|
||||
"openai/gpt-4o:extended",
|
||||
"arcee-ai/virtuoso-large",
|
||||
"openai/o3-deep-research",
|
||||
"openai/o4-mini-deep-research",
|
||||
"openrouter/free",
|
||||
}
|
||||
EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
||||
|
||||
|
||||
def is_text_output_model(model: dict[str, Any]) -> bool:
|
||||
output_mods = model.get("architecture", {}).get("output_modalities", [])
|
||||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def is_image_output_model(model: dict[str, Any]) -> bool:
|
||||
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
|
||||
return "image" in output_mods
|
||||
|
||||
|
||||
def supports_image_input(model: dict[str, Any]) -> bool:
|
||||
input_mods = model.get("architecture", {}).get("input_modalities", []) or []
|
||||
return "image" in input_mods
|
||||
|
||||
|
||||
def supports_tool_calling(model: dict[str, Any]) -> bool:
|
||||
supported = model.get("supported_parameters") or []
|
||||
return "tools" in supported
|
||||
|
||||
|
||||
def has_sufficient_context(model: dict[str, Any]) -> bool:
|
||||
return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH
|
||||
|
||||
|
||||
def is_compatible_provider(model: dict[str, Any]) -> bool:
|
||||
model_id = str(model.get("id") or "")
|
||||
slug = model_id.split("/", 1)[0] if "/" in model_id else ""
|
||||
return slug not in EXCLUDED_PROVIDER_SLUGS
|
||||
|
||||
|
||||
def is_allowed_model(model: dict[str, Any]) -> bool:
|
||||
model_id = str(model.get("id") or "")
|
||||
if model_id in EXCLUDED_MODEL_IDS:
|
||||
return False
|
||||
base_id = model_id.split(":")[0]
|
||||
return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES)
|
||||
|
||||
|
||||
def is_openrouter_chat_model(model: dict[str, Any]) -> bool:
|
||||
return (
|
||||
"/" in str(model.get("id") or "")
|
||||
and is_text_output_model(model)
|
||||
and supports_tool_calling(model)
|
||||
and has_sufficient_context(model)
|
||||
and is_compatible_provider(model)
|
||||
and is_allowed_model(model)
|
||||
)
|
||||
|
||||
|
||||
def is_openrouter_image_model(model: dict[str, Any]) -> bool:
|
||||
return (
|
||||
"/" in str(model.get("id") or "")
|
||||
and is_image_output_model(model)
|
||||
and is_compatible_provider(model)
|
||||
and is_allowed_model(model)
|
||||
)
|
||||
|
||||
|
||||
def normalize_openrouter_models(
|
||||
raw_models: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for model in raw_models:
|
||||
if not is_openrouter_chat_model(model):
|
||||
continue
|
||||
model_id = str(model.get("id") or "")
|
||||
normalized.append(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": model.get("name") or model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"supports_chat": True,
|
||||
"max_input_tokens": model.get("context_length"),
|
||||
"supports_image_input": supports_image_input(model),
|
||||
"supports_tools": supports_tool_calling(model),
|
||||
"supports_image_generation": False,
|
||||
"metadata": model,
|
||||
}
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MIN_CONTEXT_LENGTH",
|
||||
"has_sufficient_context",
|
||||
"is_allowed_model",
|
||||
"is_compatible_provider",
|
||||
"is_image_output_model",
|
||||
"is_openrouter_chat_model",
|
||||
"is_openrouter_image_model",
|
||||
"is_text_output_model",
|
||||
"normalize_openrouter_models",
|
||||
"supports_image_input",
|
||||
"supports_tool_calling",
|
||||
]
|
||||
|
|
@ -143,21 +143,19 @@ def _register_chat_shape_configs(
|
|||
sample_keys: list[str] = []
|
||||
|
||||
for cfg in configs:
|
||||
provider = str(cfg.get("provider") or "").upper()
|
||||
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
|
||||
model_name = str(cfg.get("model_name") or "").strip()
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = str(litellm_params.get("base_model") or model_name).strip()
|
||||
|
||||
if provider == "OPENROUTER":
|
||||
if provider == "openrouter":
|
||||
entry = or_pricing.get(model_name)
|
||||
if entry:
|
||||
input_cost = _safe_float(entry.get("prompt"))
|
||||
output_cost = _safe_float(entry.get("completion"))
|
||||
else:
|
||||
# Vision configs from ``_generate_vision_llm_configs``
|
||||
# carry their pricing inline because the OpenRouter
|
||||
# raw-pricing cache is keyed by chat-catalogue model_id;
|
||||
# vision flows pick up the inline values here.
|
||||
# Some dynamically materialized configs can carry pricing
|
||||
# inline when the raw OpenRouter cache has no matching entry.
|
||||
input_cost = _safe_float(cfg.get("input_cost_per_token"))
|
||||
output_cost = _safe_float(cfg.get("output_cost_per_token"))
|
||||
if input_cost == 0.0 and output_cost == 0.0:
|
||||
|
|
@ -189,12 +187,11 @@ def _register_chat_shape_configs(
|
|||
skipped_no_pricing += 1
|
||||
continue
|
||||
aliases = _alias_set_for_yaml(provider, model_name, base_model)
|
||||
provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
|
||||
count = _register(
|
||||
aliases,
|
||||
input_cost=input_cost,
|
||||
output_cost=output_cost,
|
||||
provider=provider_slug,
|
||||
provider=provider,
|
||||
)
|
||||
if count > 0:
|
||||
registered_models += 1
|
||||
|
|
@ -217,9 +214,8 @@ def _register_chat_shape_configs(
|
|||
def register_pricing_from_global_configs() -> None:
|
||||
"""Register pricing for every known LLM deployment with LiteLLM.
|
||||
|
||||
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
|
||||
so vision calls (during indexing) can resolve cost the same way chat
|
||||
calls do — namely:
|
||||
Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve
|
||||
cost from the same chat-shaped deployment configs:
|
||||
|
||||
1. ``OPENROUTER``: pulls the cached raw pricing from
|
||||
``OpenRouterIntegrationService`` (populated during its own
|
||||
|
|
@ -246,10 +242,7 @@ def register_pricing_from_global_configs() -> None:
|
|||
from app.config import config as app_config
|
||||
|
||||
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
|
||||
vision_configs: list[dict] = list(
|
||||
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
|
||||
)
|
||||
if not chat_configs and not vision_configs:
|
||||
if not chat_configs:
|
||||
logger.info("[PricingRegistration] no global configs to register")
|
||||
return
|
||||
|
||||
|
|
@ -268,7 +261,3 @@ def register_pricing_from_global_configs() -> None:
|
|||
|
||||
if chat_configs:
|
||||
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
|
||||
if vision_configs:
|
||||
_register_chat_shape_configs(
|
||||
vision_configs, or_pricing=or_pricing, label="vision"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,106 +0,0 @@
|
|||
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
|
||||
|
||||
LiteLLM falls back to the module-global ``litellm.api_base`` when an
|
||||
individual call doesn't pass one, which silently inherits provider-agnostic
|
||||
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
|
||||
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
|
||||
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
|
||||
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
|
||||
``/chat/completions`` to whatever inherited base it gets, regardless of
|
||||
provider).
|
||||
|
||||
The chat router has had this defense for a while
|
||||
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
|
||||
into a tiny standalone helper so vision and image-gen can share the same
|
||||
source of truth without an inter-service circular import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
|
||||
|
||||
Only providers with a well-known, stable public base URL are listed —
|
||||
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
so their existing config-driven behaviour is preserved."""
|
||||
|
||||
|
||||
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"MINIMAX": "https://api.minimax.io/v1",
|
||||
}
|
||||
"""Canonical provider key (uppercase) → base URL.
|
||||
|
||||
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
|
||||
config's ``provider`` field tells us which API it actually is (DeepSeek,
|
||||
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
|
||||
has its own base URL)."""
|
||||
|
||||
|
||||
def resolve_api_base(
|
||||
*,
|
||||
provider: str | None,
|
||||
provider_prefix: str | None,
|
||||
config_api_base: str | None,
|
||||
) -> str | None:
|
||||
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
|
||||
|
||||
Cascade (first non-empty wins):
|
||||
1. The config's own ``api_base`` (whitespace-only treated as missing).
|
||||
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
|
||||
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
|
||||
4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
|
||||
provider integration apply its own default (e.g. AzureOpenAI's
|
||||
deployment-derived URL, custom provider's per-deployment URL).
|
||||
|
||||
Args:
|
||||
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
|
||||
``"DEEPSEEK"``). Case-insensitive.
|
||||
provider_prefix: The LiteLLM model-string prefix the same call
|
||||
site builds for the model id (e.g. ``"openrouter"``,
|
||||
``"groq"``). Case-insensitive.
|
||||
config_api_base: ``api_base`` from the global YAML / DB row /
|
||||
OpenRouter dynamic config. Empty / whitespace-only means
|
||||
"missing" — the resolver still applies the cascade.
|
||||
|
||||
Returns:
|
||||
A URL string, or ``None`` if no default applies for this provider.
|
||||
"""
|
||||
if config_api_base and config_api_base.strip():
|
||||
return config_api_base
|
||||
|
||||
if provider:
|
||||
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
|
||||
if key_default:
|
||||
return key_default
|
||||
|
||||
if provider_prefix:
|
||||
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
|
||||
if prefix_default:
|
||||
return prefix_default
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROVIDER_DEFAULT_API_BASE",
|
||||
"PROVIDER_KEY_DEFAULT_API_BASE",
|
||||
"resolve_api_base",
|
||||
]
|
||||
|
|
@ -49,51 +49,6 @@ import litellm
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Provider-name → LiteLLM model-prefix map.
|
||||
#
|
||||
# Owned here because ``app.services.provider_capabilities`` is the
|
||||
# only edge that's safe to call from ``app.config``'s YAML loader at
|
||||
# class-body init time. ``app.agents.chat.runtime.llm_config`` re-exports
|
||||
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||
# map there directly would re-introduce the
|
||||
# ``app.config -> ... -> deliverables/tools/generate_image ->
|
||||
# app.config`` cycle that prompted the move.
|
||||
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"COMETAPI": "cometapi",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
def _candidate_model_strings(
|
||||
*,
|
||||
provider: str | None,
|
||||
|
|
@ -123,12 +78,7 @@ def _candidate_model_strings(
|
|||
seen.add(key)
|
||||
candidates.append(key)
|
||||
|
||||
provider_prefix: str | None = None
|
||||
if provider:
|
||||
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
|
||||
if custom_provider:
|
||||
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
|
||||
provider_prefix = custom_provider
|
||||
provider_prefix = custom_provider or provider
|
||||
|
||||
primary_model = base_model or model_name
|
||||
bare_model = model_name
|
||||
|
|
|
|||
126
surfsense_backend/app/services/provider_registry.py
Normal file
126
surfsense_backend/app/services/provider_registry.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""Provider registry for model connections.
|
||||
|
||||
The provider string is the single public identity axis. This registry only
|
||||
describes providers whose behavior differs from LiteLLM's native default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class Transport(StrEnum):
|
||||
NATIVE = "NATIVE"
|
||||
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
|
||||
OLLAMA = "OLLAMA"
|
||||
|
||||
|
||||
DiscoveryKind = Literal[
|
||||
"ollama",
|
||||
"openai_models",
|
||||
"anthropic_models",
|
||||
"bedrock_models",
|
||||
"openrouter",
|
||||
"static",
|
||||
"none",
|
||||
]
|
||||
|
||||
AuthStyle = Literal["bearer", "x-api-key", "none", "native"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderSpec:
|
||||
transport: Transport
|
||||
litellm_prefix: str | None
|
||||
discovery: DiscoveryKind
|
||||
default_base_url: str | None
|
||||
base_url_required: bool
|
||||
auth_style: AuthStyle
|
||||
display_name: str | None = None
|
||||
|
||||
|
||||
REGISTRY: dict[str, ProviderSpec] = {
|
||||
"openai": ProviderSpec(
|
||||
Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI"
|
||||
),
|
||||
"anthropic": ProviderSpec(
|
||||
Transport.NATIVE,
|
||||
"anthropic",
|
||||
"anthropic_models",
|
||||
None,
|
||||
False,
|
||||
"x-api-key",
|
||||
"Anthropic",
|
||||
),
|
||||
"azure": ProviderSpec(
|
||||
Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI"
|
||||
),
|
||||
"vertex_ai": ProviderSpec(
|
||||
Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI"
|
||||
),
|
||||
"bedrock": ProviderSpec(
|
||||
Transport.NATIVE,
|
||||
"bedrock",
|
||||
"bedrock_models",
|
||||
None,
|
||||
False,
|
||||
"native",
|
||||
"Amazon Bedrock",
|
||||
),
|
||||
"openrouter": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
"openrouter",
|
||||
"openrouter",
|
||||
"https://openrouter.ai/api/v1",
|
||||
False,
|
||||
"bearer",
|
||||
"OpenRouter",
|
||||
),
|
||||
"openai_compatible": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
"openai",
|
||||
"openai_models",
|
||||
None,
|
||||
True,
|
||||
"bearer",
|
||||
"OpenAI-compatible provider",
|
||||
),
|
||||
"lm_studio": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
"openai",
|
||||
"openai_models",
|
||||
"http://host.docker.internal:1234/v1",
|
||||
True,
|
||||
"bearer",
|
||||
"LM Studio",
|
||||
),
|
||||
"ollama_chat": ProviderSpec(
|
||||
Transport.OLLAMA,
|
||||
"ollama_chat",
|
||||
"ollama",
|
||||
"http://host.docker.internal:11434",
|
||||
True,
|
||||
"none",
|
||||
"Ollama",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def spec_for(provider: str | None) -> ProviderSpec:
|
||||
provider_key = (provider or "").strip()
|
||||
return REGISTRY.get(provider_key) or ProviderSpec(
|
||||
Transport.NATIVE, provider_key or "openai", "static", None, False, "native"
|
||||
)
|
||||
|
||||
|
||||
def provider_label(provider: str | None) -> str:
|
||||
provider_key = (provider or "").strip()
|
||||
spec = spec_for(provider_key)
|
||||
if spec.display_name:
|
||||
return spec.display_name
|
||||
return provider_key.replace("_", " ").title() if provider_key else "Provider"
|
||||
|
||||
|
||||
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"]
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Pure-function quality scoring for Auto (Fastest) model selection.
|
||||
"""Pure-function quality scoring for Auto model selection.
|
||||
|
||||
This module is import-free of any service / request-path dependencies. All
|
||||
numbers are computed once during the OpenRouter refresh tick (or YAML load)
|
||||
|
|
@ -108,25 +108,23 @@ PROVIDER_PRESTIGE_OR: dict[str, int] = {
|
|||
|
||||
# YAML provider field (the upstream API shape the operator selected).
|
||||
PROVIDER_PRESTIGE_YAML: dict[str, int] = {
|
||||
"AZURE_OPENAI": 50,
|
||||
"OPENAI": 50,
|
||||
"ANTHROPIC": 50,
|
||||
"GOOGLE": 50,
|
||||
"VERTEX_AI": 50,
|
||||
"GEMINI": 50,
|
||||
"XAI": 50,
|
||||
"MISTRAL": 38,
|
||||
"DEEPSEEK": 38,
|
||||
"COHERE": 38,
|
||||
"GROQ": 30,
|
||||
"TOGETHER_AI": 28,
|
||||
"FIREWORKS_AI": 28,
|
||||
"PERPLEXITY": 28,
|
||||
"MINIMAX": 28,
|
||||
"BEDROCK": 28,
|
||||
"OPENROUTER": 25,
|
||||
"OLLAMA": 12,
|
||||
"CUSTOM": 12,
|
||||
"azure": 50,
|
||||
"openai": 50,
|
||||
"anthropic": 50,
|
||||
"gemini": 50,
|
||||
"vertex_ai": 50,
|
||||
"xai": 50,
|
||||
"mistral": 38,
|
||||
"deepseek": 38,
|
||||
"cohere": 38,
|
||||
"groq": 30,
|
||||
"together_ai": 28,
|
||||
"fireworks_ai": 28,
|
||||
"perplexity": 28,
|
||||
"bedrock": 28,
|
||||
"openrouter": 25,
|
||||
"ollama_chat": 12,
|
||||
"custom": 12,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -275,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int:
|
|||
listed this model. Pricing / context fall through to lazy ``litellm``
|
||||
lookups; failures are silent (we just lose those sub-points).
|
||||
"""
|
||||
provider = str(cfg.get("provider", "")).upper()
|
||||
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
|
||||
base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
|
||||
|
||||
model_name = cfg.get("model_name") or ""
|
||||
|
|
|
|||
|
|
@ -40,6 +40,10 @@ class TokenCallRecord:
|
|||
total_tokens: int
|
||||
cost_micros: int = 0
|
||||
call_kind: str = "chat"
|
||||
model_ref: str | None = None
|
||||
model_id: str | None = None
|
||||
display_name: str | None = None
|
||||
provider: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -47,6 +51,24 @@ class TurnTokenAccumulator:
|
|||
"""Accumulates token usage across all LLM calls within a single user turn."""
|
||||
|
||||
calls: list[TokenCallRecord] = field(default_factory=list)
|
||||
model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict)
|
||||
|
||||
def register_model_metadata(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
model_ref: str | None,
|
||||
model_id: str | None,
|
||||
display_name: str | None,
|
||||
provider: str | None,
|
||||
) -> None:
|
||||
"""Attach resolved model metadata for later LiteLLM callback attribution."""
|
||||
self.model_metadata[model] = {
|
||||
"model_ref": model_ref,
|
||||
"model_id": model_id,
|
||||
"display_name": display_name,
|
||||
"provider": provider,
|
||||
}
|
||||
|
||||
def add(
|
||||
self,
|
||||
|
|
@ -57,9 +79,14 @@ class TurnTokenAccumulator:
|
|||
cost_micros: int = 0,
|
||||
call_kind: str = "chat",
|
||||
) -> None:
|
||||
metadata = self.model_metadata.get(model, {})
|
||||
self.calls.append(
|
||||
TokenCallRecord(
|
||||
model=model,
|
||||
model_ref=metadata.get("model_ref"),
|
||||
model_id=metadata.get("model_id"),
|
||||
display_name=metadata.get("display_name"),
|
||||
provider=metadata.get("provider"),
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
|
|
@ -68,13 +95,18 @@ class TurnTokenAccumulator:
|
|||
)
|
||||
)
|
||||
|
||||
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||
def per_message_summary(self) -> dict[str, dict[str, Any]]:
|
||||
"""Return token counts (and cost) grouped by model name."""
|
||||
by_model: dict[str, dict[str, int]] = {}
|
||||
by_model: dict[str, dict[str, Any]] = {}
|
||||
for c in self.calls:
|
||||
entry = by_model.setdefault(
|
||||
c.model,
|
||||
{
|
||||
"model": c.model,
|
||||
"model_ref": c.model_ref,
|
||||
"model_id": c.model_id,
|
||||
"display_name": c.display_name,
|
||||
"provider": c.provider,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
|
|
@ -142,6 +174,27 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
|
|||
return _turn_accumulator.get()
|
||||
|
||||
|
||||
def register_model_usage_metadata(
|
||||
*,
|
||||
model: str,
|
||||
model_ref: str | None,
|
||||
model_id: str | None,
|
||||
display_name: str | None,
|
||||
provider: str | None,
|
||||
) -> None:
|
||||
"""Register resolved model metadata with the current turn, if one exists."""
|
||||
acc = _turn_accumulator.get()
|
||||
if acc is None:
|
||||
return
|
||||
acc.register_model_metadata(
|
||||
model=model,
|
||||
model_ref=model_ref,
|
||||
model_id=model_id,
|
||||
display_name=display_name,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
|
||||
"""Async context manager that scopes a fresh ``TurnTokenAccumulator``
|
||||
|
|
|
|||
|
|
@ -1,201 +0,0 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from litellm import Router
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VISION_AUTO_MODE_ID = 0
|
||||
|
||||
VISION_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GOOGLE": "gemini",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"XAI": "xai",
|
||||
"OPENROUTER": "openrouter",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"GROQ": "groq",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"MISTRAL": "mistral",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
class VisionLLMRouterService:
|
||||
_instance = None
|
||||
_router: Router | None = None
|
||||
_model_list: list[dict] = []
|
||||
_router_settings: dict = {}
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "VisionLLMRouterService":
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
global_configs: list[dict],
|
||||
router_settings: dict | None = None,
|
||||
) -> None:
|
||||
instance = cls.get_instance()
|
||||
|
||||
if instance._initialized:
|
||||
logger.debug("Vision LLM Router already initialized, skipping")
|
||||
return
|
||||
|
||||
model_list = []
|
||||
for config in global_configs:
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
||||
if not model_list:
|
||||
logger.warning(
|
||||
"No valid vision LLM configs found for router initialization"
|
||||
)
|
||||
return
|
||||
|
||||
instance._model_list = model_list
|
||||
instance._router_settings = router_settings or {}
|
||||
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
"retry_after": 5,
|
||||
}
|
||||
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False,
|
||||
)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
"Vision LLM Router initialized with %d deployments, strategy: %s",
|
||||
len(model_list),
|
||||
final_settings.get("routing_strategy"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Vision LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
try:
|
||||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
provider_prefix = config["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
else:
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
litellm_params: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
|
||||
deployment: dict[str, Any] = {
|
||||
"model_name": "auto",
|
||||
"litellm_params": litellm_params,
|
||||
}
|
||||
|
||||
if config.get("rpm"):
|
||||
deployment["rpm"] = config["rpm"]
|
||||
if config.get("tpm"):
|
||||
deployment["tpm"] = config["tpm"]
|
||||
|
||||
return deployment
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert vision config to deployment: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_router(cls) -> Router | None:
|
||||
instance = cls.get_instance()
|
||||
return instance._router
|
||||
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
instance = cls.get_instance()
|
||||
return instance._initialized and instance._router is not None
|
||||
|
||||
@classmethod
|
||||
def get_model_count(cls) -> int:
|
||||
instance = cls.get_instance()
|
||||
return len(instance._model_list)
|
||||
|
||||
|
||||
def is_vision_auto_mode(config_id: int | None) -> bool:
|
||||
return config_id == VISION_AUTO_MODE_ID
|
||||
|
||||
|
||||
def build_vision_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
def get_global_vision_llm_config(config_id: int) -> dict | None:
|
||||
from app.config import config
|
||||
|
||||
if config_id == VISION_AUTO_MODE_ID:
|
||||
return {
|
||||
"id": VISION_AUTO_MODE_ID,
|
||||
"name": "Auto (Fastest)",
|
||||
"provider": "AUTO",
|
||||
"model_name": "auto",
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
if config_id > 0:
|
||||
return None
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
"""
|
||||
Service for fetching and caching the vision-capable model list.
|
||||
|
||||
Reuses the same OpenRouter public API and local fallback as the LLM model
|
||||
list service, but filters for models that accept image input.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
FALLBACK_FILE = (
|
||||
Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json"
|
||||
)
|
||||
CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
_cache: list[dict] | None = None
|
||||
_cache_timestamp: float = 0
|
||||
|
||||
OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = {
|
||||
"openai": "OPENAI",
|
||||
"anthropic": "ANTHROPIC",
|
||||
"google": "GOOGLE",
|
||||
"mistralai": "MISTRAL",
|
||||
"x-ai": "XAI",
|
||||
}
|
||||
|
||||
|
||||
def _format_context_length(length: int | None) -> str | None:
|
||||
if not length:
|
||||
return None
|
||||
if length >= 1_000_000:
|
||||
return f"{length / 1_000_000:g}M"
|
||||
if length >= 1_000:
|
||||
return f"{length / 1_000:g}K"
|
||||
return str(length)
|
||||
|
||||
|
||||
async def _fetch_from_openrouter() -> list[dict] | None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
response = await client.get(OPENROUTER_API_URL)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _load_fallback() -> list[dict]:
|
||||
try:
|
||||
with open(FALLBACK_FILE, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load vision model fallback list: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _is_vision_model(model: dict) -> bool:
|
||||
"""Return True if the model accepts image input and outputs text."""
|
||||
arch = model.get("architecture", {})
|
||||
input_mods = arch.get("input_modalities", [])
|
||||
output_mods = arch.get("output_modalities", [])
|
||||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _process_vision_models(raw_models: list[dict]) -> list[dict]:
|
||||
processed: list[dict] = []
|
||||
|
||||
for model in raw_models:
|
||||
model_id: str = model.get("id", "")
|
||||
name: str = model.get("name", "")
|
||||
context_length = model.get("context_length")
|
||||
|
||||
if "/" not in model_id:
|
||||
continue
|
||||
|
||||
if not _is_vision_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_id,
|
||||
"label": name,
|
||||
"provider": "OPENROUTER",
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
|
||||
if native_provider:
|
||||
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
continue
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_name,
|
||||
"label": name,
|
||||
"provider": native_provider,
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
async def get_vision_model_list() -> list[dict]:
|
||||
global _cache, _cache_timestamp
|
||||
|
||||
if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS:
|
||||
return _cache
|
||||
|
||||
raw_models = await _fetch_from_openrouter()
|
||||
|
||||
if raw_models is None:
|
||||
logger.info("Using fallback vision model list")
|
||||
return _load_fallback()
|
||||
|
||||
processed = _process_vision_models(raw_models)
|
||||
|
||||
_cache = processed
|
||||
_cache_timestamp = time.time()
|
||||
|
||||
return processed
|
||||
88
surfsense_backend/app/tasks/chat/llm_history_normalizer.py
Normal file
88
surfsense_backend/app/tasks/chat/llm_history_normalizer.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""Convert persisted chat content into provider-safe LangChain history.
|
||||
|
||||
Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module
|
||||
extracts only model-safe content before prior turns are replayed to a provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
_USER_CONTENT_TYPES = {"text", "image", "image_url"}
|
||||
|
||||
|
||||
def _text_from_block(block: dict[str, Any]) -> str:
|
||||
value = block.get("text") or block.get("content") or ""
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
|
||||
def assistant_content_to_llm_text(content: Any) -> str:
|
||||
"""Return visible assistant text, dropping reasoning/UI/provider blocks."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, dict):
|
||||
return _text_from_block(content)
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
text_chunks: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
text_chunks.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text = _text_from_block(block)
|
||||
if text:
|
||||
text_chunks.append(text)
|
||||
return "\n".join(text_chunks)
|
||||
|
||||
|
||||
def user_content_to_llm_content(
|
||||
content: Any,
|
||||
*,
|
||||
allow_images: bool = True,
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Return provider-safe user text/image content for LangChain."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, dict):
|
||||
return _text_from_block(content)
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
parts: list[dict[str, Any]] = []
|
||||
text_chunks: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
text_chunks.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type not in _USER_CONTENT_TYPES:
|
||||
continue
|
||||
if block_type == "text":
|
||||
text = _text_from_block(block)
|
||||
if text:
|
||||
parts.append({"type": "text", "text": text})
|
||||
text_chunks.append(text)
|
||||
elif allow_images and block_type == "image":
|
||||
image = block.get("image")
|
||||
if isinstance(image, str) and image.startswith("data:"):
|
||||
parts.append({"type": "image_url", "image_url": {"url": image}})
|
||||
elif allow_images and block_type == "image_url":
|
||||
image_url = block.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
url = image_url.get("url")
|
||||
if isinstance(url, str) and url.startswith("data:"):
|
||||
parts.append({"type": "image_url", "image_url": {"url": url}})
|
||||
elif isinstance(image_url, str) and image_url.startswith("data:"):
|
||||
parts.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
|
||||
if allow_images and any(part.get("type") == "image_url" for part in parts):
|
||||
return parts
|
||||
return "\n".join(text_chunks)
|
||||
89
surfsense_backend/app/tasks/chat/message_parts_normalizer.py
Normal file
89
surfsense_backend/app/tasks/chat/message_parts_normalizer.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Normalize final LangChain assistant messages into assistant-ui parts.
|
||||
|
||||
Live streaming remains the primary source for rich, incremental UI state.
|
||||
This module is only used after the graph has finished so refresh persistence
|
||||
does not depend on provider-specific streaming chunk shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
def _text_from_content(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") != "text":
|
||||
continue
|
||||
value = block.get("text") or block.get("content") or ""
|
||||
if isinstance(value, str) and value:
|
||||
text_parts.append(value)
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def normalize_ai_message_to_parts(
|
||||
message: AIMessage | Any | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return user-visible assistant-ui parts for a final AI message.
|
||||
|
||||
We intentionally do not backfill provider ``thinking`` /
|
||||
``reasoning_content`` blocks here. If reasoning streamed live, the
|
||||
``AssistantContentBuilder`` already captured it. If it only exists in the
|
||||
final model payload, persisting it retroactively could expose content the
|
||||
UI never showed during the turn.
|
||||
"""
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
text = _text_from_content(getattr(message, "content", None)).strip()
|
||||
if not text:
|
||||
return []
|
||||
return [{"type": "text", "text": text}]
|
||||
|
||||
|
||||
def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None:
|
||||
if messages is None:
|
||||
return None
|
||||
for message in reversed(list(messages)):
|
||||
if isinstance(message, AIMessage):
|
||||
return message
|
||||
if getattr(message, "type", None) == "ai":
|
||||
return message
|
||||
return None
|
||||
|
||||
|
||||
def final_assistant_parts_from_messages(
|
||||
messages: Iterable[Any] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return normalize_ai_message_to_parts(last_ai_message(messages))
|
||||
|
||||
|
||||
def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool:
|
||||
return any(
|
||||
part.get("type") == "text"
|
||||
and isinstance(part.get("text"), str)
|
||||
and bool(part.get("text", "").strip())
|
||||
for part in parts
|
||||
)
|
||||
|
||||
|
||||
def merge_streamed_and_final_parts(
|
||||
streamed_parts: list[dict[str, Any]],
|
||||
final_parts: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Use final-state text only when streaming captured no answer text."""
|
||||
if has_non_empty_text_part(streamed_parts):
|
||||
return streamed_parts
|
||||
if not has_non_empty_text_part(final_parts):
|
||||
return streamed_parts
|
||||
return [*streamed_parts, *final_parts]
|
||||
|
|
@ -16,6 +16,9 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence impor
|
|||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.message_parts_normalizer import (
|
||||
final_assistant_parts_from_messages,
|
||||
)
|
||||
from app.tasks.chat.streaming.contract.file_contract import (
|
||||
contract_enforcement_active,
|
||||
evaluate_file_contract_outcome,
|
||||
|
|
@ -75,6 +78,9 @@ async def stream_agent_events(
|
|||
|
||||
state = await agent.aget_state(config)
|
||||
state_values = getattr(state, "values", {}) or {}
|
||||
result.final_message_parts = final_assistant_parts_from_messages(
|
||||
state_values.get("messages")
|
||||
)
|
||||
|
||||
# Safety net: if astream_events was cancelled before
|
||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import (
|
|||
is_cancel_requested,
|
||||
)
|
||||
from app.agents.chat.runtime.errors import BusyError
|
||||
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
|
||||
|
||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||
|
|
@ -102,6 +103,9 @@ def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
|
|||
|
||||
def is_provider_rate_limited(exc: BaseException) -> bool:
|
||||
"""Return True if the exception looks like an upstream HTTP 429 / rate limit."""
|
||||
if adapt_llm_exception(exc).category is LLMErrorCategory.RATE_LIMITED:
|
||||
return True
|
||||
|
||||
raw = str(exc)
|
||||
lowered = raw.lower()
|
||||
if "ratelimit" in type(exc).__name__.lower():
|
||||
|
|
@ -131,6 +135,85 @@ def is_provider_rate_limited(exc: BaseException) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _provider_error_extra(adapted: Any) -> dict[str, Any] | None:
|
||||
extra: dict[str, Any] = {"provider_error_category": adapted.category.value}
|
||||
if adapted.provider_status_code is not None:
|
||||
extra["provider_status_code"] = adapted.provider_status_code
|
||||
if adapted.provider_error_type:
|
||||
extra["provider_error_type"] = adapted.provider_error_type
|
||||
return extra
|
||||
|
||||
|
||||
def _classify_provider_exception(
|
||||
exc: Exception,
|
||||
) -> (
|
||||
tuple[str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None]
|
||||
| None
|
||||
):
|
||||
adapted = adapt_llm_exception(exc)
|
||||
|
||||
if adapted.category is LLMErrorCategory.RATE_LIMITED:
|
||||
return (
|
||||
"rate_limited",
|
||||
"RATE_LIMITED",
|
||||
"warn",
|
||||
True,
|
||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||
_provider_error_extra(adapted),
|
||||
)
|
||||
|
||||
if adapted.category in {
|
||||
LLMErrorCategory.AUTH_FAILED,
|
||||
LLMErrorCategory.PERMISSION_DENIED,
|
||||
}:
|
||||
return (
|
||||
"model_auth_failed",
|
||||
"MODEL_AUTH_FAILED",
|
||||
"warn",
|
||||
True,
|
||||
"This model's API key is invalid or expired. Switch models, or update the API key.",
|
||||
_provider_error_extra(adapted),
|
||||
)
|
||||
|
||||
if adapted.category is LLMErrorCategory.MODEL_NOT_FOUND:
|
||||
return (
|
||||
"model_not_found",
|
||||
"MODEL_NOT_FOUND",
|
||||
"warn",
|
||||
True,
|
||||
"The selected model is unavailable or no longer exists. Switch to another model and try again.",
|
||||
_provider_error_extra(adapted),
|
||||
)
|
||||
|
||||
if adapted.category is LLMErrorCategory.CONTEXT_LIMIT:
|
||||
return (
|
||||
"model_context_limit",
|
||||
"MODEL_CONTEXT_LIMIT",
|
||||
"warn",
|
||||
True,
|
||||
"This request is too large for the selected model. Try a model with a larger context window or reduce the input.",
|
||||
_provider_error_extra(adapted),
|
||||
)
|
||||
|
||||
if adapted.category in {
|
||||
LLMErrorCategory.TIMEOUT,
|
||||
LLMErrorCategory.PROVIDER_UNAVAILABLE,
|
||||
LLMErrorCategory.BAD_GATEWAY,
|
||||
LLMErrorCategory.CONNECTION_FAILED,
|
||||
LLMErrorCategory.SERVER_ERROR,
|
||||
}:
|
||||
return (
|
||||
"model_provider_unavailable",
|
||||
"MODEL_PROVIDER_UNAVAILABLE",
|
||||
"warn",
|
||||
True,
|
||||
"The selected model provider is temporarily unavailable. Please try again or switch models.",
|
||||
_provider_error_extra(adapted),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def classify_stream_exception(
|
||||
exc: Exception,
|
||||
*,
|
||||
|
|
@ -167,15 +250,9 @@ def classify_stream_exception(
|
|||
None,
|
||||
)
|
||||
|
||||
if is_provider_rate_limited(exc):
|
||||
return (
|
||||
"rate_limited",
|
||||
"RATE_LIMITED",
|
||||
"warn",
|
||||
True,
|
||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||
None,
|
||||
)
|
||||
provider_classification = _classify_provider_exception(exc)
|
||||
if provider_classification is not None:
|
||||
return provider_classification
|
||||
|
||||
return (
|
||||
"server_error",
|
||||
|
|
|
|||
|
|
@ -80,7 +80,6 @@ async def _generate_title(
|
|||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
# Excludes this turn's own assistant row (pre-written by
|
||||
|
|
@ -125,26 +124,12 @@ async def _generate_title(
|
|||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
# Apply the same ``api_base`` cascade chat / vision / image-gen
|
||||
# call sites use so we never inherit ``litellm.api_base``
|
||||
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
|
||||
# config itself ships an empty ``api_base``. Without this the
|
||||
# title-gen on an OpenRouter chat config would 404 against the
|
||||
# inherited Azure endpoint — see ``provider_api_base`` for the
|
||||
# same bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
provider_value = agent_config.provider if agent_config is not None else None
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
response = await acompletion(
|
||||
model=raw_model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=title_api_base,
|
||||
api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ async def finalize_assistant_message(
|
|||
):
|
||||
return
|
||||
|
||||
from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts
|
||||
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||
|
||||
builder_stats: dict[str, int] | None = None
|
||||
|
|
@ -74,6 +75,10 @@ async def finalize_assistant_message(
|
|||
"text": stream_result.accumulated_text or "",
|
||||
}
|
||||
]
|
||||
content_payload = merge_streamed_and_final_parts(
|
||||
content_payload,
|
||||
stream_result.final_message_parts,
|
||||
)
|
||||
|
||||
if builder_stats is not None:
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Load an LLM + AgentConfig bundle for a given config id.
|
||||
|
||||
Handles both code paths uniformly:
|
||||
- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space).
|
||||
- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults).
|
||||
- ``config_id > 0`` → database-backed model-connection ``Model`` row.
|
||||
- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter.
|
||||
|
||||
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
|
||||
``None``. The caller emits the friendly SSE error frame.
|
||||
|
|
@ -12,15 +12,78 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
create_chat_litellm_from_config,
|
||||
load_agent_config,
|
||||
load_global_llm_config_by_id,
|
||||
SanitizedChatLiteLLM,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import Model, SearchSpace
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.services.token_tracking_service import register_model_usage_metadata
|
||||
|
||||
|
||||
def _agent_config_from_resolved(
|
||||
*,
|
||||
config_id: int,
|
||||
config_name: str | None,
|
||||
provider: str,
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
litellm_params: dict | None,
|
||||
supports_image_input: bool,
|
||||
billing_tier: str = "free",
|
||||
) -> AgentConfig:
|
||||
return AgentConfig(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
api_key=api_key or "",
|
||||
api_base=api_base,
|
||||
custom_provider=None,
|
||||
litellm_params=litellm_params,
|
||||
config_id=config_id,
|
||||
config_name=config_name,
|
||||
is_auto_mode=False,
|
||||
billing_tier=billing_tier,
|
||||
is_premium=billing_tier == "premium",
|
||||
supports_image_input=supports_image_input,
|
||||
)
|
||||
|
||||
|
||||
async def _load_search_space(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> SearchSpace | None:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def _load_db_model(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
model_id: int,
|
||||
search_space: SearchSpace,
|
||||
) -> Model | None:
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.where(Model.id == model_id, Model.enabled.is_(True))
|
||||
)
|
||||
model = result.scalars().first()
|
||||
if not model or not model.connection or not model.connection.enabled:
|
||||
return None
|
||||
conn = model.connection
|
||||
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
|
||||
return None
|
||||
if conn.user_id is not None and conn.user_id != search_space.user_id:
|
||||
return None
|
||||
return model
|
||||
|
||||
|
||||
async def load_llm_bundle(
|
||||
|
|
@ -29,29 +92,89 @@ async def load_llm_bundle(
|
|||
config_id: int,
|
||||
search_space_id: int,
|
||||
) -> tuple[Any, AgentConfig | None, str | None]:
|
||||
if config_id >= 0:
|
||||
loaded_agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=config_id,
|
||||
search_space_id=search_space_id,
|
||||
search_space = await _load_search_space(session, search_space_id)
|
||||
if not search_space:
|
||||
return None, None, f"Search space {search_space_id} not found"
|
||||
|
||||
if config_id > 0:
|
||||
model = await _load_db_model(
|
||||
session,
|
||||
model_id=config_id,
|
||||
search_space=search_space,
|
||||
)
|
||||
if not loaded_agent_config:
|
||||
if not model or not has_capability(model, "chat"):
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
f"Failed to load NewLLMConfig with id {config_id}",
|
||||
f"Failed to load chat model with id {config_id}",
|
||||
)
|
||||
model_string, litellm_kwargs = to_litellm(model.connection, model.model_id)
|
||||
display_name = model.display_name or model.model_id
|
||||
provider = model.connection.provider or ""
|
||||
register_model_usage_metadata(
|
||||
model=model_string,
|
||||
model_ref=f"db:{model.id}",
|
||||
model_id=model.model_id,
|
||||
display_name=display_name,
|
||||
provider=provider,
|
||||
)
|
||||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=display_name,
|
||||
provider=provider,
|
||||
model_name=model.model_id,
|
||||
api_key=model.connection.api_key,
|
||||
api_base=model.connection.base_url,
|
||||
litellm_params=(model.connection.extra or {}).get("litellm_params"),
|
||||
supports_image_input=has_capability(model, "vision"),
|
||||
billing_tier="free",
|
||||
)
|
||||
return (
|
||||
create_chat_litellm_from_agent_config(loaded_agent_config),
|
||||
loaded_agent_config,
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
||||
loaded_llm_config = load_global_llm_config_by_id(config_id)
|
||||
if not loaded_llm_config:
|
||||
return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
return (
|
||||
create_chat_litellm_from_config(loaded_llm_config),
|
||||
AgentConfig.from_yaml_config(loaded_llm_config),
|
||||
global_model = next(
|
||||
(m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None
|
||||
)
|
||||
if not global_model or not has_capability(global_model, "chat"):
|
||||
return None, None, f"Failed to load global chat model with id {config_id}"
|
||||
global_connection = next(
|
||||
(
|
||||
c
|
||||
for c in config.GLOBAL_CONNECTIONS
|
||||
if c.get("id") == global_model.get("connection_id")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not global_connection:
|
||||
return None, None, f"Failed to load global connection for model {config_id}"
|
||||
model_string, litellm_kwargs = to_litellm(
|
||||
global_connection, global_model["model_id"]
|
||||
)
|
||||
display_name = global_model.get("display_name") or global_model.get("model_id")
|
||||
provider = global_connection.get("provider") or ""
|
||||
register_model_usage_metadata(
|
||||
model=model_string,
|
||||
model_ref=f"global:{config_id}",
|
||||
model_id=global_model["model_id"],
|
||||
display_name=display_name,
|
||||
provider=provider,
|
||||
)
|
||||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=display_name,
|
||||
provider=provider,
|
||||
model_name=global_model["model_id"],
|
||||
api_key=global_connection.get("api_key"),
|
||||
api_base=global_connection.get("base_url"),
|
||||
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
|
||||
supports_image_input=has_capability(global_model, "vision"),
|
||||
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
|
||||
)
|
||||
return (
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -35,3 +35,7 @@ class StreamResult:
|
|||
# (``StreamResult`` is logged in some error branches) from dumping a
|
||||
# potentially-large parts list.
|
||||
content_builder: Any | None = field(default=None, repr=False)
|
||||
# User-visible assistant message parts derived from the final LangGraph
|
||||
# state. Used after streaming completes as a provider-agnostic persistence
|
||||
# backfill when no text chunks reached the live stream.
|
||||
final_message_parts: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.tasks.chat.llm_history_normalizer import (
|
||||
assistant_content_to_llm_text,
|
||||
user_content_to_llm_content,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
|
@ -95,17 +100,28 @@ async def bootstrap_history_from_db(
|
|||
langchain_messages: list[HumanMessage | AIMessage] = []
|
||||
|
||||
for msg in db_messages:
|
||||
text_content = extract_text_content(msg.content)
|
||||
if not text_content:
|
||||
continue
|
||||
if msg.role == "user":
|
||||
user_content = user_content_to_llm_content(
|
||||
msg.content,
|
||||
allow_images=False,
|
||||
)
|
||||
if not user_content:
|
||||
continue
|
||||
if is_shared:
|
||||
author_name = (
|
||||
msg.author.display_name if msg.author else None
|
||||
) or "A team member"
|
||||
text_content = f"**[{author_name}]:** {text_content}"
|
||||
langchain_messages.append(HumanMessage(content=text_content))
|
||||
if isinstance(user_content, str):
|
||||
user_content = f"**[{author_name}]:** {user_content}"
|
||||
elif user_content and user_content[0].get("type") == "text":
|
||||
user_content[0] = {
|
||||
**user_content[0],
|
||||
"text": f"**[{author_name}]:** {user_content[0].get('text', '')}",
|
||||
}
|
||||
langchain_messages.append(HumanMessage(content=user_content))
|
||||
elif msg.role == "assistant":
|
||||
langchain_messages.append(AIMessage(content=text_content))
|
||||
assistant_text = assistant_content_to_llm_text(msg.content)
|
||||
if assistant_text:
|
||||
langchain_messages.append(AIMessage(content=assistant_text))
|
||||
|
||||
return langchain_messages
|
||||
|
|
|
|||
|
|
@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402
|
|||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
OpenRouterIntegrationService,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base # noqa: E402
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
derive_supports_image_input,
|
||||
is_known_text_only_chat_model,
|
||||
|
|
@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
|||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
cap = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
block = is_known_text_only_chat_model(
|
||||
provider=cfg.get("provider"),
|
||||
provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
|||
def _build_chat_model_string(cfg: dict) -> str:
|
||||
if cfg.get("custom_provider"):
|
||||
return f"{cfg['custom_provider']}/{cfg['model_name']}"
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
prefix = cfg.get("litellm_provider") or "openai"
|
||||
return f"{prefix}/{cfg['model_name']}"
|
||||
|
||||
|
||||
|
|
@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str:
|
|||
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
|
||||
model_string = _build_chat_model_string(cfg)
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=model_string.split("/", 1)[0],
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
|
|
@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
|||
"max_tokens": 16,
|
||||
"timeout": 60,
|
||||
}
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
if cfg.get("api_base"):
|
||||
kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("litellm_params"):
|
||||
# Strip pricing keys — they're tracking-only and confuse some
|
||||
# provider validators (e.g. azure/openai reject unknown kwargs
|
||||
|
|
@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = (
|
|||
|
||||
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Generate one tiny image to verify the deployment is reachable."""
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
if cfg.get("custom_provider"):
|
||||
prefix = cfg["custom_provider"]
|
||||
else:
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
prefix = cfg.get("litellm_provider") or "openai"
|
||||
model_string = f"{prefix}/{cfg['model_name']}"
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=prefix,
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
base_kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
|
|
@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
|||
"size": "1024x1024",
|
||||
"timeout": 120,
|
||||
}
|
||||
if api_base:
|
||||
base_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_base"):
|
||||
base_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
base_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -349,31 +330,6 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None:
|
|||
report.add(result)
|
||||
|
||||
|
||||
async def probe_vision_configs(report: Report, *, live: bool) -> None:
|
||||
print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if _is_or_dynamic(cfg):
|
||||
continue
|
||||
result = ProbeResult(
|
||||
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||
surface="vision",
|
||||
config_id=cfg.get("id"),
|
||||
)
|
||||
# For vision configs, capability is implied — they're in the
|
||||
# dedicated vision pool. Run the same resolver to flag any
|
||||
# surprise disagreement.
|
||||
cap_ok, cap_note = _probe_chat_capability(cfg)
|
||||
result.capability_ok = cap_ok
|
||||
result.capability_note = cap_note
|
||||
if live:
|
||||
t0 = time.perf_counter()
|
||||
ok, note = await _live_chat_image_call(cfg)
|
||||
result.live_ok = ok
|
||||
result.live_note = note
|
||||
result.duration_s = time.perf_counter() - t0
|
||||
report.add(result)
|
||||
|
||||
|
||||
async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
|
||||
print(
|
||||
"\n[image generation configs from global_image_generation_configs (YAML-static)]"
|
||||
|
|
@ -399,7 +355,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
|
|||
|
||||
|
||||
async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
||||
"""Sample one chat (vision-capable), one vision, one image-gen model
|
||||
"""Sample chat/vision-capable and image-gen models
|
||||
from the live OpenRouter catalogue. Doesn't iterate the full pool
|
||||
(would be hundreds of probes); just validates the integration end-
|
||||
to-end on a representative model from each surface."""
|
||||
|
|
@ -424,9 +380,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
for c in config.GLOBAL_LLM_CONFIGS
|
||||
if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
|
||||
]
|
||||
or_vision = [
|
||||
c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||
]
|
||||
or_image_gen = [
|
||||
c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||
]
|
||||
|
|
@ -446,11 +399,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
("or-chat", _pick_first(or_chat, "anthropic/claude")),
|
||||
("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
|
||||
]
|
||||
vision_picks = [
|
||||
("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
|
||||
("or-vision", _pick_first(or_vision, "anthropic/claude")),
|
||||
("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
|
||||
]
|
||||
image_picks = [
|
||||
("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
|
||||
# OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
|
||||
|
|
@ -460,11 +408,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
]
|
||||
|
||||
print(
|
||||
f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
|
||||
f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} "
|
||||
f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
|
||||
)
|
||||
|
||||
for surface, picked in chat_picks + vision_picks + image_picks:
|
||||
for surface, picked in chat_picks + image_picks:
|
||||
if not picked:
|
||||
report.add(
|
||||
ProbeResult(
|
||||
|
|
@ -505,7 +453,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
|||
async def main(args: argparse.Namespace) -> int:
|
||||
print("Loaded global configs:")
|
||||
print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
|
||||
print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
|
||||
print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
|
||||
print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
|
||||
|
||||
|
|
@ -526,8 +473,6 @@ async def main(args: argparse.Namespace) -> int:
|
|||
report = Report()
|
||||
if not args.skip_chat:
|
||||
await probe_chat_configs(report, live=args.live)
|
||||
if not args.skip_vision:
|
||||
await probe_vision_configs(report, live=args.live)
|
||||
if not args.skip_image_gen:
|
||||
await probe_image_gen_configs(report, live=args.live)
|
||||
if not args.skip_openrouter:
|
||||
|
|
@ -547,7 +492,6 @@ def _parse_args() -> argparse.Namespace:
|
|||
)
|
||||
parser.set_defaults(live=True)
|
||||
parser.add_argument("--skip-chat", action="store_true")
|
||||
parser.add_argument("--skip-vision", action="store_true")
|
||||
parser.add_argument("--skip-image-gen", action="store_true")
|
||||
parser.add_argument("--skip-openrouter", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
# so the resolved auto-pin id is never sent to a real LLM provider.
|
||||
# The values below only need to pass
|
||||
# auto_model_pin_service._is_usable_global_config()
|
||||
# which requires id / model_name / provider / api_key all truthy.
|
||||
# which requires id / model_name / litellm_provider / api_key all truthy.
|
||||
#
|
||||
# Why TWO entries (premium + free):
|
||||
# auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits
|
||||
|
|
@ -44,9 +44,10 @@ global_llm_configs:
|
|||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
quality_score: 1.0
|
||||
provider: "OPENAI"
|
||||
litellm_provider: "openai"
|
||||
model_name: "fake-e2e-model-premium"
|
||||
api_key: "fake-e2e-api-key-not-for-production"
|
||||
api_base: "https://api.openai.com/v1"
|
||||
supports_image_input: false
|
||||
quota_reserve_tokens: 1024
|
||||
rpm: 1000
|
||||
|
|
@ -60,9 +61,10 @@ global_llm_configs:
|
|||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
quality_score: 1.0
|
||||
provider: "OPENAI"
|
||||
litellm_provider: "openai"
|
||||
model_name: "fake-e2e-model-free"
|
||||
api_key: "fake-e2e-api-key-not-for-production"
|
||||
api_base: "https://api.openai.com/v1"
|
||||
supports_image_input: false
|
||||
quota_reserve_tokens: 1024
|
||||
rpm: 1000
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
"""Regression tests for model-boundary message sanitization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agents.chat.runtime.llm_config import _sanitize_messages
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None:
|
||||
original = AIMessage(
|
||||
content=[
|
||||
{"type": "thinking", "thinking": "private reasoning"},
|
||||
{"type": "text", "text": "visible answer"},
|
||||
]
|
||||
)
|
||||
|
||||
sanitized = _sanitize_messages([original])
|
||||
|
||||
assert sanitized[0].content == "visible answer"
|
||||
assert original.content == [
|
||||
{"type": "thinking", "thinking": "private reasoning"},
|
||||
{"type": "text", "text": "visible answer"},
|
||||
]
|
||||
|
||||
|
||||
def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None:
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}],
|
||||
)
|
||||
|
||||
sanitized = _sanitize_messages([message])
|
||||
|
||||
assert sanitized[0].content is None
|
||||
assert message.content == ""
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
"""Lock the runtime model-policy backstop in ``build_dependencies``.
|
||||
|
||||
Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so
|
||||
Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so
|
||||
runs are insulated from later chat/search-space model changes), and the model
|
||||
policy is re-checked at run time so a captured model that is no longer billable
|
||||
fails the run clearly. When no snapshot is present, resolution falls back to the
|
||||
|
|
@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch):
|
|||
return None
|
||||
|
||||
|
||||
async def test_build_dependencies_resolves_captured_agent_llm_id(
|
||||
async def test_build_dependencies_resolves_captured_chat_model_id(
|
||||
monkeypatch: pytest.MonkeyPatch, patched_side_effects
|
||||
) -> None:
|
||||
"""The bundle loads with the *captured* ``agent_llm_id``, not the live search space."""
|
||||
"""The bundle loads with the *captured* ``chat_model_id``, not the live search space."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _fake_load(_session, *, config_id, search_space_id):
|
||||
|
|
@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_agent_llm_id(
|
|||
lambda _ss: pytest.fail("search-space policy should not run on captured path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-99)
|
||||
search_space = SimpleNamespace(chat_model_id=-99)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
assert captured == {"config_id": -7, "search_space_id": 42}
|
||||
|
|
@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids(
|
|||
monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load)
|
||||
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=0)),
|
||||
session=_FakeSession(SimpleNamespace(chat_model_id=0)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
assert seen == {
|
||||
"agent_llm_id": -7,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -7,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
|
|||
|
||||
def _raise(**_kw):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "image", "config_id": -2, "reason": "free model"}]
|
||||
[{"kind": "image", "model_id": -2, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(deps_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation(
|
|||
|
||||
with pytest.raises(DependencyError):
|
||||
await build_dependencies(
|
||||
session=_FakeSession(SimpleNamespace(agent_llm_id=-7)),
|
||||
session=_FakeSession(SimpleNamespace(chat_model_id=-7)),
|
||||
search_space_id=42,
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=-2,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=-2,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space(
|
|||
lambda **_kw: pytest.fail("captured policy should not run on fallback path"),
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-7)
|
||||
search_space = SimpleNamespace(chat_model_id=-7)
|
||||
result = await build_dependencies(
|
||||
session=_FakeSession(search_space), search_space_id=42
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ def _run() -> SimpleNamespace:
|
|||
def test_build_action_ctx_propagates_captured_models() -> None:
|
||||
"""``definition.models`` flows onto the ActionContext model fields."""
|
||||
models = AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
ctx = _build_action_ctx(
|
||||
cast(AsyncSession, None),
|
||||
|
|
@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None:
|
|||
)
|
||||
|
||||
assert ctx.search_space_id == 42
|
||||
assert ctx.agent_llm_id == -1
|
||||
assert ctx.image_generation_config_id == 5
|
||||
assert ctx.vision_llm_config_id == -1
|
||||
assert ctx.chat_model_id == -1
|
||||
assert ctx.image_gen_model_id == 5
|
||||
assert ctx.vision_model_id == -1
|
||||
|
||||
|
||||
def test_build_action_ctx_none_models_leaves_fields_none() -> None:
|
||||
|
|
@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None:
|
|||
None,
|
||||
)
|
||||
|
||||
assert ctx.agent_llm_id is None
|
||||
assert ctx.image_generation_config_id is None
|
||||
assert ctx.vision_llm_config_id is None
|
||||
assert ctx.chat_model_id is None
|
||||
assert ctx.image_gen_model_id is None
|
||||
assert ctx.vision_model_id is None
|
||||
|
|
|
|||
|
|
@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None:
|
|||
name="Daily digest",
|
||||
plan=[PlanStep(step_id="s1", action="agent_task")],
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
),
|
||||
)
|
||||
|
||||
dumped = definition.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
restored = AutomationDefinition.model_validate(dumped)
|
||||
assert restored.models is not None
|
||||
assert restored.models.agent_llm_id == -1
|
||||
assert restored.models.image_generation_config_id == 5
|
||||
assert restored.models.vision_llm_config_id == -1
|
||||
assert restored.models.chat_model_id == -1
|
||||
assert restored.models.image_gen_model_id == 5
|
||||
assert restored.models.vision_model_id == -1
|
||||
|
||||
|
||||
def test_automation_definition_rejects_unknown_top_level_field() -> None:
|
||||
|
|
|
|||
|
|
@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation(
|
|||
|
||||
def _raise(_ss):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": 0, "reason": "Auto mode"}]
|
||||
[{"kind": "llm", "model_id": 0, "reason": "Auto mode"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=0))
|
||||
service = _service(SimpleNamespace(chat_model_id=0))
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service._assert_models_billable(1)
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok(
|
|||
automation_mod, "assert_automation_models_billable", lambda _ss: None
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(agent_llm_id=-1)
|
||||
search_space = SimpleNamespace(chat_model_id=-1)
|
||||
service = _service(search_space)
|
||||
assert await service._assert_models_billable(1) is search_space
|
||||
|
||||
|
|
@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space(
|
|||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(
|
||||
|
|
@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space(
|
|||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
|
|||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
search_space = SimpleNamespace(
|
||||
agent_llm_id=None,
|
||||
image_generation_config_id=None,
|
||||
vision_llm_config_id=None,
|
||||
chat_model_id=None,
|
||||
image_gen_model_id=None,
|
||||
vision_model_id=None,
|
||||
)
|
||||
service = _service(search_space)
|
||||
payload = AutomationCreate(search_space_id=1, name="A", definition=_definition())
|
||||
|
|
@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero(
|
|||
automation = await service.create(payload)
|
||||
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": 0,
|
||||
"image_generation_config_id": 0,
|
||||
"vision_llm_config_id": 0,
|
||||
"chat_model_id": 0,
|
||||
"image_gen_model_id": 0,
|
||||
"vision_model_id": 0,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided(
|
|||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
chat_model_id,
|
||||
image_gen_model_id,
|
||||
vision_model_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
|
@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided(
|
|||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-99))
|
||||
service = _service(SimpleNamespace(chat_model_id=-99))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-1,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-1,
|
||||
image_gen_model_id=7,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided(
|
|||
|
||||
assert validated["ids"] == (-1, 7, -2)
|
||||
assert automation.definition["models"] == {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 7,
|
||||
"vision_llm_config_id": -2,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 7,
|
||||
"vision_model_id": -2,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models(
|
|||
) -> None:
|
||||
"""A non-billable explicit selection maps the policy error to HTTP 422."""
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -3, "reason": "free model"}]
|
||||
[{"kind": "llm", "model_id": -3, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models(
|
|||
|
||||
monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-3))
|
||||
service = _service(SimpleNamespace(chat_model_id=-3))
|
||||
payload = AutomationCreate(
|
||||
search_space_id=1,
|
||||
name="A",
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-3,
|
||||
image_generation_config_id=7,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-3,
|
||||
image_gen_model_id=7,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -277,9 +277,9 @@ async def test_update_preserves_captured_models(
|
|||
) -> None:
|
||||
"""A definition edit carries over the previously captured ``models``."""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
|
|
@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid(
|
|||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
validated: dict[str, Any] = {}
|
||||
|
||||
def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
validated["ids"] = (
|
||||
agent_llm_id,
|
||||
image_generation_config_id,
|
||||
vision_llm_config_id,
|
||||
chat_model_id,
|
||||
image_gen_model_id,
|
||||
vision_model_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok)
|
||||
|
|
@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid(
|
|||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-2,
|
||||
image_generation_config_id=9,
|
||||
vision_llm_config_id=-2,
|
||||
chat_model_id=-2,
|
||||
image_gen_model_id=9,
|
||||
vision_model_id=-2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid(
|
|||
|
||||
assert validated["ids"] == (-2, 9, -2)
|
||||
assert result.definition["models"] == {
|
||||
"agent_llm_id": -2,
|
||||
"image_generation_config_id": 9,
|
||||
"vision_llm_config_id": -2,
|
||||
"chat_model_id": -2,
|
||||
"image_gen_model_id": 9,
|
||||
"vision_model_id": -2,
|
||||
}
|
||||
assert result.version == 4
|
||||
|
||||
|
|
@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models(
|
|||
"name": "A",
|
||||
"plan": [],
|
||||
"models": {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
},
|
||||
},
|
||||
version=3,
|
||||
)
|
||||
|
||||
def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id):
|
||||
def _raise(*, chat_model_id, image_gen_model_id, vision_model_id):
|
||||
raise AutomationModelPolicyError(
|
||||
[{"kind": "llm", "config_id": -7, "reason": "free model"}]
|
||||
[{"kind": "llm", "model_id": -7, "reason": "free model"}]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(automation_mod, "assert_models_billable", _raise)
|
||||
|
|
@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models(
|
|||
patch = AutomationUpdate(
|
||||
definition=_definition(
|
||||
models=AutomationModels(
|
||||
agent_llm_id=-7,
|
||||
image_generation_config_id=5,
|
||||
vision_llm_config_id=-1,
|
||||
chat_model_id=-7,
|
||||
image_gen_model_id=5,
|
||||
vision_model_id=-1,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation(
|
|||
premium without an unrelated edit tripping the policy check.
|
||||
"""
|
||||
captured = {
|
||||
"agent_llm_id": -1,
|
||||
"image_generation_config_id": 5,
|
||||
"vision_llm_config_id": -1,
|
||||
"chat_model_id": -1,
|
||||
"image_gen_model_id": 5,
|
||||
"vision_model_id": -1,
|
||||
}
|
||||
existing = SimpleNamespace(
|
||||
search_space_id=1,
|
||||
|
|
@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload(
|
|||
lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]},
|
||||
)
|
||||
|
||||
service = _service(SimpleNamespace(agent_llm_id=-2))
|
||||
service = _service(SimpleNamespace(chat_model_id=-2))
|
||||
result = await service.model_eligibility(search_space_id=5)
|
||||
|
||||
assert result == {"allowed": False, "violations": [{"kind": "image"}]}
|
||||
|
|
|
|||
|
|
@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit
|
|||
def _search_space(*, llm: int | None, image: int | None, vision: int | None):
|
||||
"""Minimal stand-in for the ``SearchSpace`` ORM row the policy reads."""
|
||||
return SimpleNamespace(
|
||||
agent_llm_id=llm,
|
||||
image_generation_config_id=image,
|
||||
vision_llm_config_id=vision,
|
||||
chat_model_id=llm,
|
||||
image_gen_model_id=image,
|
||||
vision_model_id=vision,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -39,29 +39,11 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
|||
|
||||
Negative ids: -1 is premium, -2 is free, for each of llm/image/vision.
|
||||
"""
|
||||
llm_configs = {
|
||||
-1: {"id": -1, "billing_tier": "premium"},
|
||||
-2: {"id": -2, "billing_tier": "free"},
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"app.agents.chat.runtime.llm_config.load_global_llm_config_by_id",
|
||||
lambda cid: llm_configs.get(cid),
|
||||
)
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
"GLOBAL_MODELS",
|
||||
[
|
||||
{"id": -1, "billing_tier": "premium"},
|
||||
{"id": -2, "billing_tier": "free"},
|
||||
|
|
@ -71,7 +53,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
|||
return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A positive config id is a user-owned BYOK model — always billable."""
|
||||
allowed, reason = model_policy._classify(kind, 7)
|
||||
|
|
@ -79,7 +61,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None:
|
|||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
@pytest.mark.parametrize("config_id", [0, None])
|
||||
def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
|
||||
"""Auto mode (id 0) and an unset slot (None) are blocked."""
|
||||
|
|
@ -88,7 +70,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None:
|
|||
assert "Auto mode" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with premium billing tier is allowed."""
|
||||
allowed, reason = model_policy._classify(kind, -1)
|
||||
|
|
@ -96,7 +78,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None:
|
|||
assert reason == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_free_global_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative (global) id with a free billing tier is blocked."""
|
||||
allowed, reason = model_policy._classify(kind, -2)
|
||||
|
|
@ -104,7 +86,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None:
|
|||
assert "free model" in reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kind", ["llm", "image", "vision"])
|
||||
@pytest.mark.parametrize("kind", ["chat", "image", "vision"])
|
||||
def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None:
|
||||
"""A negative id that resolves to no config is treated as not premium."""
|
||||
allowed, _ = model_policy._classify(kind, -999)
|
||||
|
|
@ -125,10 +107,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None:
|
|||
|
||||
assert result["allowed"] is False
|
||||
kinds = {v["kind"] for v in result["violations"]}
|
||||
assert kinds == {"llm", "image", "vision"}
|
||||
# config_id is echoed back for the UI / settings deep-link.
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
assert kinds == {"chat", "image", "vision"}
|
||||
# model_id is echoed back for the UI / settings deep-link.
|
||||
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
|
||||
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_raises_with_violations(patched_globals) -> None:
|
||||
|
|
@ -138,7 +120,7 @@ def test_assert_raises_with_violations(patched_globals) -> None:
|
|||
assert_automation_models_billable(search_space)
|
||||
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
assert exc_info.value.violations[0]["kind"] == "chat"
|
||||
|
||||
|
||||
def test_assert_passes_when_all_billable(patched_globals) -> None:
|
||||
|
|
@ -153,7 +135,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None:
|
|||
def test_get_model_eligibility_all_billable(patched_globals) -> None:
|
||||
"""Premium LLM + BYOK image + premium vision (explicit ids) → allowed."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1
|
||||
)
|
||||
assert result == {"allowed": True, "violations": []}
|
||||
|
||||
|
|
@ -161,28 +143,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None:
|
|||
def test_get_model_eligibility_reports_each_violation(patched_globals) -> None:
|
||||
"""Free LLM, Auto image, free vision (explicit ids) each produce a violation."""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
|
||||
)
|
||||
assert result["allowed"] is False
|
||||
by_kind = {v["kind"]: v["config_id"] for v in result["violations"]}
|
||||
assert by_kind == {"llm": -2, "image": 0, "vision": -2}
|
||||
by_kind = {v["kind"]: v["model_id"] for v in result["violations"]}
|
||||
assert by_kind == {"chat": -2, "image": 0, "vision": -2}
|
||||
|
||||
|
||||
def test_assert_models_billable_raises(patched_globals) -> None:
|
||||
"""``assert_models_billable`` raises when any explicit id is blocked."""
|
||||
with pytest.raises(AutomationModelPolicyError) as exc_info:
|
||||
assert_models_billable(
|
||||
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1
|
||||
chat_model_id=0, image_gen_model_id=5, vision_model_id=-1
|
||||
)
|
||||
assert len(exc_info.value.violations) == 1
|
||||
assert exc_info.value.violations[0]["kind"] == "llm"
|
||||
assert exc_info.value.violations[0]["kind"] == "chat"
|
||||
|
||||
|
||||
def test_assert_models_billable_passes(patched_globals) -> None:
|
||||
"""No exception when every explicit id is premium or BYOK."""
|
||||
assert (
|
||||
assert_models_billable(
|
||||
agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4
|
||||
chat_model_id=3, image_gen_model_id=-1, vision_model_id=4
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
|
@ -192,5 +174,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None:
|
|||
"""The search-space wrapper produces the same result as the ID core."""
|
||||
search_space = _search_space(llm=-2, image=0, vision=-2)
|
||||
assert get_automation_model_eligibility(search_space) == get_model_eligibility(
|
||||
agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2
|
||||
chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,110 +0,0 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
|
||||
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
|
||||
|
||||
There is no DB column for ``supports_image_input`` on
|
||||
``NewLLMConfig`` — the value is resolved at the API boundary by
|
||||
``derive_supports_image_input`` so the new-chat selector / streaming
|
||||
task can read the same field shape regardless of source (BYOK vs YAML
|
||||
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
|
||||
user out of their own model choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _byok_row(
|
||||
*,
|
||||
id_: int,
|
||||
model_name: str,
|
||||
base_model: str | None = None,
|
||||
provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
|
||||
custom_provider: str | None = None,
|
||||
) -> object:
|
||||
"""Mimic the SQLAlchemy row's attribute surface; ``model_validate``
|
||||
walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
|
||||
|
||||
``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
|
||||
enum validator accepts it — same as the ORM row would carry."""
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
name=f"BYOK-{id_}",
|
||||
description=None,
|
||||
provider=provider,
|
||||
custom_provider=custom_provider,
|
||||
model_name=model_name,
|
||||
api_key="sk-byok",
|
||||
api_base=None,
|
||||
litellm_params={"base_model": base_model} if base_model else None,
|
||||
system_instructions="",
|
||||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
created_at=datetime.now(tz=UTC),
|
||||
search_space_id=42,
|
||||
user_id=uuid4(),
|
||||
)
|
||||
|
||||
|
||||
def test_serialize_byok_known_vision_model_resolves_true():
|
||||
"""The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
|
||||
True. The serialized row carries that value through to the
|
||||
``NewLLMConfigRead`` schema."""
|
||||
row = _byok_row(id_=1, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
assert serialized.id == 1
|
||||
assert serialized.model_name == "gpt-4o"
|
||||
|
||||
|
||||
def test_serialize_byok_unknown_model_default_allows():
|
||||
"""Unknown / unmapped: default-allow. The streaming-task safety net
|
||||
is the actual block, and it requires LiteLLM to *explicitly* say
|
||||
text-only — so a brand new BYOK model should not be pre-judged."""
|
||||
row = _byok_row(
|
||||
id_=2,
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
provider=LiteLLMProvider.CUSTOM,
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_uses_base_model_when_present():
|
||||
"""Azure-style: ``model_name`` is the deployment id, ``base_model``
|
||||
inside ``litellm_params`` is the canonical sku LiteLLM knows. The
|
||||
helper must consult ``base_model`` first or unrecognised deployment
|
||||
ids would shadow the real capability."""
|
||||
row = _byok_row(
|
||||
id_=3,
|
||||
model_name="my-azure-deployment-id-no-litellm-knows-this",
|
||||
base_model="gpt-4o",
|
||||
provider=LiteLLMProvider.AZURE_OPENAI,
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_returns_pydantic_read_model():
|
||||
"""The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
|
||||
the schema additions are guaranteed to be present in the API
|
||||
surface. This guards against a future regression where someone
|
||||
deletes the augmentation step and falls back to ORM passthrough."""
|
||||
from app.schemas import NewLLMConfigRead
|
||||
|
||||
row = _byok_row(id_=4, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
assert isinstance(serialized, NewLLMConfigRead)
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
"""Unit tests for ``is_premium`` derivation on the global image-gen and
|
||||
vision-LLM list endpoints.
|
||||
|
||||
Chat globals (``GET /global-llm-configs``) already emit
|
||||
``is_premium = (billing_tier == "premium")``. Image and vision did not,
|
||||
which made the new-chat ``model-selector`` render the Free/Premium badge
|
||||
on the Chat tab but skip it on the Image and Vision tabs (the selector
|
||||
keys its badge logic off ``is_premium``). These tests pin parity:
|
||||
|
||||
* YAML free entry → ``is_premium=False``
|
||||
* YAML premium entry → ``is_premium=True``
|
||||
* OpenRouter dynamic premium entry → ``is_premium=True``
|
||||
* Auto stub (always emitted when at least one config is present)
|
||||
→ ``is_premium=False``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_IMAGE_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "DALL-E 3",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "dall-e-3",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "GPT-Image 1 (premium)",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -20_001,
|
||||
"name": "google/gemini-2.5-flash-image (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
_VISION_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o Vision",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "Claude 3.5 Sonnet (premium)",
|
||||
"provider": "ANTHROPIC",
|
||||
"model_name": "claude-3-5-sonnet",
|
||||
"api_key": "sk-ant-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -30_001,
|
||||
"name": "openai/gpt-4o (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image generation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
|
||||
"""Each emitted config must carry ``is_premium`` derived server-side
|
||||
from ``billing_tier``. The Auto stub is always free.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub is always emitted when at least one global config exists,
|
||||
# and it must always declare itself free (Auto-mode billing-tier
|
||||
# surfacing is a separate follow-up).
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
# YAML free entry — ``is_premium=False``
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
# YAML premium entry — ``is_premium=True``
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
# OpenRouter dynamic premium entry — same field, same derivation
|
||||
assert by_id[-20_001]["is_premium"] is True
|
||||
assert by_id[-20_001]["billing_tier"] == "premium"
|
||||
|
||||
# Every emitted dict (including Auto) must have the field — never missing.
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
"""When there are no global configs at all, the endpoint emits an
|
||||
empty list (no Auto stub) — Auto mode would have nothing to route to.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
assert payload == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
assert by_id[-30_001]["is_premium"] is True
|
||||
assert by_id[-30_001]["billing_tier"] == "premium"
|
||||
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
assert payload == []
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on the chat global
|
||||
config endpoint (``GET /global-new-llm-configs``).
|
||||
|
||||
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
|
||||
|
||||
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
|
||||
loader for operator overrides, or by the OpenRouter integration from
|
||||
``architecture.input_modalities``) — wins.
|
||||
2. ``derive_supports_image_input`` helper — default-allow on unknown
|
||||
models, only False when LiteLLM / OR modalities are definitive.
|
||||
|
||||
The flag is purely informational at the API boundary. The streaming
|
||||
task safety net (``is_known_text_only_chat_model``) is the actual block,
|
||||
and it requires LiteLLM to *explicitly* mark the model as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o (explicit true)",
|
||||
"description": "vision-capable, explicit YAML override",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "DeepSeek V3 (explicit false)",
|
||||
"description": "OpenRouter dynamic — modality-derived false",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "deepseek/deepseek-v3.2-exp",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": False,
|
||||
},
|
||||
{
|
||||
"id": -10_010,
|
||||
"name": "Unannotated GPT-4o",
|
||||
"description": "no flag set — resolver should derive True via LiteLLM",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
# supports_image_input intentionally absent
|
||||
},
|
||||
{
|
||||
"id": -10_011,
|
||||
"name": "Unannotated unknown model",
|
||||
"description": "unmapped — default-allow True",
|
||||
"provider": "CUSTOM",
|
||||
"custom_provider": "brand_new_proxy",
|
||||
"model_name": "brand-new-model-x9",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
|
||||
"""Each emitted chat config carries ``supports_image_input`` as a
|
||||
bool. Explicit values win; unannotated entries are resolved via the
|
||||
helper (default-allow True)."""
|
||||
from app.config import config
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
|
||||
|
||||
payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub: optimistic True so the user can keep Auto selected with
|
||||
# vision-capable deployments somewhere in the pool.
|
||||
assert 0 in by_id, "Auto stub should be emitted when configs exist"
|
||||
assert by_id[0]["supports_image_input"] is True
|
||||
assert by_id[0]["is_auto_mode"] is True
|
||||
|
||||
# Explicit True is preserved.
|
||||
assert by_id[-1]["supports_image_input"] is True
|
||||
|
||||
# Explicit False is preserved (the exact failure mode the safety net
|
||||
# guards against — DeepSeek V3 over OpenRouter would 404 with "No
|
||||
# endpoints found that support image input").
|
||||
assert by_id[-2]["supports_image_input"] is False
|
||||
|
||||
# Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
|
||||
assert by_id[-10_010]["supports_image_input"] is True
|
||||
|
||||
# Unknown / unmapped model: default-allow rather than pre-judge.
|
||||
assert by_id[-10_011]["supports_image_input"] is True
|
||||
|
||||
for cfg in payload:
|
||||
assert "supports_image_input" in cfg, (
|
||||
f"supports_image_input missing from {cfg.get('id')}"
|
||||
)
|
||||
assert isinstance(cfg["supports_image_input"], bool)
|
||||
|
|
@ -27,9 +27,18 @@ async def test_resolve_billing_for_auto_mode(monkeypatch):
|
|||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
async def _no_auto_candidates(*_args, **_kwargs):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
image_generation_routes,
|
||||
"auto_model_candidates",
|
||||
_no_auto_candidates,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, # Not consumed on this code path.
|
||||
session=None,
|
||||
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
|
||||
search_space=search_space,
|
||||
)
|
||||
|
|
@ -45,26 +54,48 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch):
|
|||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
"GLOBAL_MODELS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
"quota_reserve_micros": 75_000,
|
||||
"catalog": {"quota_reserve_micros": 75_000},
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"connection_id": -102,
|
||||
"model_id": "google/gemini-2.5-flash-image",
|
||||
"billing_tier": "free",
|
||||
"catalog": {},
|
||||
},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_CONNECTIONS",
|
||||
[
|
||||
{
|
||||
"id": -101,
|
||||
"provider": "openai",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"extra": {},
|
||||
},
|
||||
{
|
||||
"id": -102,
|
||||
"provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"extra": {},
|
||||
},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
|
||||
|
||||
# Premium with override.
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
|
|
@ -94,7 +125,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free():
|
|||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=42, search_space=search_space
|
||||
)
|
||||
|
|
@ -105,7 +136,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
||||
"""When the request omits ``image_generation_config_id``, the helper
|
||||
"""When the request omits ``image_gen_model_id``, the helper
|
||||
must consult the search space's default — so a search space pinned
|
||||
to a premium global config still gates new requests by quota.
|
||||
"""
|
||||
|
|
@ -114,19 +145,34 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
|||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
"GLOBAL_MODELS",
|
||||
[
|
||||
{
|
||||
"id": -7,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
"catalog": {},
|
||||
}
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_CONNECTIONS",
|
||||
[
|
||||
{
|
||||
"id": -101,
|
||||
"provider": "openai",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"extra": {},
|
||||
}
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=-7)
|
||||
search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7)
|
||||
(
|
||||
tier,
|
||||
model,
|
||||
|
|
|
|||
|
|
@ -1,27 +1,4 @@
|
|||
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
|
||||
|
||||
Validates the resolver used by Celery podcast/video tasks to compute
|
||||
``(owner_user_id, billing_tier, base_model)`` from a search space and its
|
||||
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
|
||||
``stream_new_chat.py:2294-2351`` and is the single integration point that
|
||||
prevents Auto-mode podcast/video from leaking premium credit.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
|
||||
global → returns ``("premium", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
|
||||
global → returns ``("free", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
|
||||
→ always ``"free"``.
|
||||
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
|
||||
hitting the pin service.
|
||||
* Negative id (no Auto) → uses ``get_global_llm_config``'s
|
||||
``billing_tier``.
|
||||
* Positive id (user BYOK) → always ``"free"``.
|
||||
* Search space not found → raises ``ValueError``.
|
||||
* ``agent_llm_id`` is None → raises ``ValueError``.
|
||||
"""
|
||||
"""Unit tests for ``_resolve_agent_billing_for_search_space``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -34,11 +11,6 @@ import pytest
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
|
@ -51,14 +23,6 @@ class _FakeExecResult:
|
|||
|
||||
|
||||
class _FakeSession:
|
||||
"""Tiny AsyncSession stub.
|
||||
|
||||
``responses`` is a list of objects to return from successive
|
||||
``execute()`` calls (in order). The resolver makes at most two
|
||||
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
|
||||
lookup), so two queued responses cover the matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, responses: list):
|
||||
self._responses = list(responses)
|
||||
|
||||
|
|
@ -67,9 +31,6 @@ class _FakeSession:
|
|||
return _FakeExecResult(None)
|
||||
return _FakeExecResult(self._responses.pop(0))
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakePinResolution:
|
||||
|
|
@ -78,53 +39,33 @@ class _FakePinResolution:
|
|||
from_existing_pin: bool = False
|
||||
|
||||
|
||||
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=42,
|
||||
agent_llm_id=agent_llm_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||
return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id)
|
||||
|
||||
|
||||
def _make_byok_config(
|
||||
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
|
||||
def _make_byok_model(
|
||||
*, id_: int, base_model: str | None = None, model_id: str = "gpt-byok"
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
model_name=model_name,
|
||||
litellm_params={"base_model": base_model} if base_model else {},
|
||||
model_id=model_id,
|
||||
catalog={"base_model": base_model} if base_model else {},
|
||||
connection=SimpleNamespace(enabled=True, search_space_id=42, user_id=None),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
||||
"""Auto + thread → pin service resolves to negative-id premium config →
|
||||
resolver returns ``("premium", <base_model>)``."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
|
||||
|
||||
# Mock the pin service to return a concrete premium config id.
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
assert selected_llm_config_id == 0
|
||||
assert thread_id == 99
|
||||
async def _fake_resolve_pin(*_args, **kwargs):
|
||||
assert kwargs["selected_llm_config_id"] == 0
|
||||
assert kwargs["thread_id"] == 99
|
||||
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
|
||||
|
||||
# Mock global config lookup to return a premium entry.
|
||||
def _fake_get_global(cfg_id):
|
||||
if cfg_id == -1:
|
||||
return {
|
||||
|
|
@ -135,8 +76,6 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
|||
}
|
||||
return None
|
||||
|
||||
# Lazy imports inside the resolver — patch the *target* modules so the
|
||||
# imported names resolve to our fakes.
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
|
|
@ -154,77 +93,18 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
|||
assert base_model == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
|
||||
"""Auto + thread → pin returns negative-id free config → resolver
|
||||
returns ``("free", <base_model>)``. Same path the pin service takes for
|
||||
out-of-credit users (graceful degradation)."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
if cfg_id == -3:
|
||||
return {
|
||||
"id": -3,
|
||||
"model_name": "openrouter/free-model",
|
||||
"billing_tier": "free",
|
||||
"litellm_params": {"base_model": "openrouter/free-model"},
|
||||
}
|
||||
return None
|
||||
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||
)
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=99
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "openrouter/free-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
|
||||
"""Auto + thread → pin returns positive-id BYOK config → resolver
|
||||
returns ``("free", ...)`` (BYOK is always free per
|
||||
``AgentConfig.from_new_llm_config``)."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
|
||||
byok_cfg = _make_byok_config(
|
||||
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
|
||||
search_space = _make_search_space(chat_model_id=0, user_id=user_id)
|
||||
byok_model = _make_byok_model(
|
||||
id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude"
|
||||
)
|
||||
session = _FakeSession([search_space, byok_cfg])
|
||||
session = _FakeSession([search_space, byok_model])
|
||||
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
async def _fake_resolve_pin(*_args, **_kwargs):
|
||||
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
|
||||
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
|
|
@ -244,13 +124,10 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_without_thread_id_falls_back_to_free():
|
||||
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
|
||||
the pin service. Forward-compat fallback for any future direct-API
|
||||
entrypoint that doesn't have a chat thread."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=None
|
||||
|
|
@ -263,13 +140,10 @@ async def test_auto_mode_without_thread_id_falls_back_to_free():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
|
||||
"""If the pin service raises ``ValueError`` (thread missing /
|
||||
mismatched search space), the resolver should log and return free
|
||||
rather than killing the whole task."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)])
|
||||
|
||||
async def _fake_resolve_pin(*args, **kwargs):
|
||||
raise ValueError("thread missing")
|
||||
|
|
@ -291,12 +165,10 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_premium_global_returns_premium(monkeypatch):
|
||||
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
|
||||
return its ``billing_tier``."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
|
|
@ -319,50 +191,15 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch):
|
|||
assert base_model == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_free_global_returns_free(monkeypatch):
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
"id": cfg_id,
|
||||
"model_name": "openrouter/some-free",
|
||||
"billing_tier": "free",
|
||||
"litellm_params": {"base_model": "openrouter/some-free"},
|
||||
}
|
||||
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=None
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "openrouter/some-free"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
|
||||
"""When the global config has no ``litellm_params.base_model``, the
|
||||
resolver falls back to ``model_name`` — matching chat's behavior."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
"id": cfg_id,
|
||||
"model_name": "fallback-model",
|
||||
"billing_tier": "premium",
|
||||
# No litellm_params.
|
||||
}
|
||||
return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"}
|
||||
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
|
|
@ -378,14 +215,12 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_positive_id_byok_is_always_free():
|
||||
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
|
||||
regardless of underlying provider tier."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
|
||||
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
|
||||
session = _FakeSession([search_space, byok_cfg])
|
||||
search_space = _make_search_space(chat_model_id=23, user_id=user_id)
|
||||
byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet")
|
||||
session = _FakeSession([search_space, byok_model])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42
|
||||
|
|
@ -398,13 +233,10 @@ async def test_positive_id_byok_is_always_free():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
|
||||
"""If the BYOK config row is missing/deleted but the search space still
|
||||
points at it, the resolver still returns free (no debit) with an empty
|
||||
base_model — billable_call's premium path is skipped, no harm done."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42
|
||||
|
|
@ -419,18 +251,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
|
|||
async def test_search_space_not_found_raises_value_error():
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
session = _FakeSession([None])
|
||||
|
||||
with pytest.raises(ValueError, match="Search space"):
|
||||
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
|
||||
await _resolve_agent_billing_for_search_space(
|
||||
_FakeSession([None]), search_space_id=999
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_llm_id_none_raises_value_error():
|
||||
async def test_chat_model_id_none_raises_value_error():
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
|
||||
session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)])
|
||||
|
||||
with pytest.raises(ValueError, match="agent_llm_id"):
|
||||
with pytest.raises(ValueError, match="chat_model_id"):
|
||||
await _resolve_agent_billing_for_search_space(session, search_space_id=42)
|
||||
|
|
|
|||
|
|
@ -17,8 +17,39 @@ from app.services.auto_model_pin_service import (
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self):
|
||||
self.values: dict[str, str] = {}
|
||||
self.ttls: dict[str, int] = {}
|
||||
|
||||
def set(self, key: str, value: str, *, ex: int | None = None):
|
||||
self.values[key] = value
|
||||
if ex is not None:
|
||||
self.ttls[key] = ex
|
||||
return True
|
||||
|
||||
def mget(self, keys: list[str]):
|
||||
return [self.values.get(key) for key in keys]
|
||||
|
||||
def delete(self, *keys: str):
|
||||
removed = 0
|
||||
for key in keys:
|
||||
if key in self.values:
|
||||
removed += 1
|
||||
self.values.pop(key, None)
|
||||
self.ttls.pop(key, None)
|
||||
return removed
|
||||
|
||||
def scan_iter(self, pattern: str):
|
||||
prefix = pattern.removesuffix("*")
|
||||
return (key for key in list(self.values) if key.startswith(prefix))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_runtime_cooldown_map():
|
||||
def _clear_runtime_cooldown_map(monkeypatch):
|
||||
import app.services.auto_model_pin_service as svc
|
||||
|
||||
monkeypatch.setattr(svc, "_runtime_cooldown_redis", _FakeRedis())
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
yield
|
||||
|
|
@ -32,8 +63,9 @@ class _FakeQuotaResult:
|
|||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
def __init__(self, *, thread=None, scalars=None):
|
||||
self._thread = thread
|
||||
self._scalars = scalars or []
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
|
@ -41,19 +73,71 @@ class _FakeExecResult:
|
|||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
def scalars(self):
|
||||
return SimpleNamespace(all=lambda: self._scalars)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
def __init__(self, thread, *, models=None):
|
||||
self.thread = thread
|
||||
self.models = models or []
|
||||
self.commit_count = 0
|
||||
self.execute_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
self.execute_count += 1
|
||||
if self.execute_count == 1:
|
||||
return _FakeExecResult(thread=self.thread)
|
||||
return _FakeExecResult(scalars=self.models)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
|
||||
def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
|
||||
"""Patch the new global model catalog shape from compact legacy cfg fixtures."""
|
||||
connections = []
|
||||
models = []
|
||||
for cfg in configs:
|
||||
config_id = int(cfg["id"])
|
||||
connection_id = config_id - 100_000
|
||||
provider = cfg.get("provider") or cfg.get("litellm_provider")
|
||||
model_name = cfg["model_name"]
|
||||
connections.append(
|
||||
{
|
||||
"id": connection_id,
|
||||
"provider": provider,
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
models.append(
|
||||
{
|
||||
"id": config_id,
|
||||
"connection_id": connection_id,
|
||||
"model_id": model_name,
|
||||
"display_name": cfg.get("name") or model_name,
|
||||
"supports_chat": cfg.get("supports_chat", True),
|
||||
"supports_image_input": cfg.get("supports_image_input", True),
|
||||
"supports_tools": cfg.get("supports_tools", True),
|
||||
"supports_image_generation": cfg.get(
|
||||
"supports_image_generation", False
|
||||
),
|
||||
"capabilities_override": cfg.get("capabilities_override") or {},
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"catalog": {
|
||||
"auto_pin_tier": cfg.get("auto_pin_tier"),
|
||||
"quality_score": cfg.get("quality_score")
|
||||
or cfg.get("quality_score_static"),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs)
|
||||
monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections)
|
||||
monkeypatch.setattr(config, "GLOBAL_MODELS", models)
|
||||
|
||||
|
||||
def _thread(
|
||||
*,
|
||||
search_space_id: int = 10,
|
||||
|
|
@ -71,14 +155,19 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -111,13 +200,13 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -125,7 +214,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -154,17 +243,19 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
||||
async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_model(
|
||||
monkeypatch,
|
||||
):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5.1",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -173,7 +264,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5.4",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -182,12 +273,39 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -3,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-5.4",
|
||||
"litellm_provider": "anthropic",
|
||||
"model_name": "claude-opus",
|
||||
"api_key": "k3",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "B",
|
||||
"quality_score": 100,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 99,
|
||||
},
|
||||
{
|
||||
"id": -4,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-5.3",
|
||||
"api_key": "k4",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 98,
|
||||
},
|
||||
{
|
||||
"id": -5,
|
||||
"litellm_provider": "gemini",
|
||||
"model_name": "gemini-3-pro",
|
||||
"api_key": "k5",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 97,
|
||||
},
|
||||
{
|
||||
"id": -6,
|
||||
"litellm_provider": "xai",
|
||||
"model_name": "grok-5",
|
||||
"api_key": "k6",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 96,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
|
@ -207,7 +325,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.resolved_llm_config_id in {-1, -3, -4, -5, -6}
|
||||
assert result.resolved_tier == "premium"
|
||||
|
||||
|
||||
|
|
@ -216,13 +334,13 @@ async def test_next_turn_reuses_existing_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -257,13 +375,13 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -295,20 +413,20 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -340,20 +458,20 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -385,20 +503,20 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -433,11 +551,16 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-2))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -458,11 +581,16 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-999))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -487,7 +615,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quality-aware pin selection (Auto Fastest upgrade)
|
||||
# Quality-aware pin selection (Auto upgrade)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -498,13 +626,13 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "venice/dead-model",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -514,7 +642,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-flash",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -550,13 +678,13 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k-yaml",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -566,7 +694,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-5",
|
||||
"api_key": "k-or",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -602,13 +730,13 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k-yaml",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -618,7 +746,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-flash:free",
|
||||
"api_key": "k-or",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -656,7 +784,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
|
|||
high_score_cfgs = [
|
||||
{
|
||||
"id": -i,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": f"gpt-x-{i}",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -668,7 +796,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
|
|||
]
|
||||
low_score_trap = {
|
||||
"id": -99,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "tiny-legacy",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -676,9 +804,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
|
|||
"quality_score": 10,
|
||||
"health_gated": False,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[*high_score_cfgs, low_score_trap],
|
||||
)
|
||||
|
||||
|
|
@ -723,13 +851,13 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "venice/dead-model",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -739,7 +867,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -775,13 +903,13 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -791,7 +919,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5-pro",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -833,13 +961,13 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -849,7 +977,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -881,18 +1009,86 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
|||
assert result.from_existing_pin is False
|
||||
|
||||
|
||||
def test_mark_runtime_cooldown_writes_shared_redis(monkeypatch):
|
||||
import app.services.auto_model_pin_service as svc
|
||||
|
||||
mark_runtime_cooldown(-9, reason="provider_rate_limited", cooldown_seconds=123)
|
||||
|
||||
redis_client = svc._runtime_cooldown_redis
|
||||
assert redis_client.values["auto:cooldown:llm:-9"] == "provider_rate_limited"
|
||||
assert redis_client.ttls["auto:cooldown:llm:-9"] == 123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shared_runtime_cooldown_blocks_pin_across_workers(monkeypatch):
|
||||
"""A Redis cooldown written by another worker should invalidate local pins."""
|
||||
import app.services.auto_model_pin_service as svc
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 90,
|
||||
"health_gated": False,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 80,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
svc._runtime_cooldown_redis.set(
|
||||
"auto:cooldown:llm:-1",
|
||||
"provider_rate_limited",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -931,13 +1127,13 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -947,7 +1143,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
|
|||
|
|
@ -45,8 +45,9 @@ class _FakeQuotaResult:
|
|||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
def __init__(self, *, thread=None, scalars=None):
|
||||
self._thread = thread
|
||||
self._scalars = scalars or []
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
|
@ -54,14 +55,21 @@ class _FakeExecResult:
|
|||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
def scalars(self):
|
||||
return SimpleNamespace(all=lambda: self._scalars)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
self.thread = thread
|
||||
self.commit_count = 0
|
||||
self.execute_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
self.execute_count += 1
|
||||
if self.execute_count == 1:
|
||||
return _FakeExecResult(thread=self.thread)
|
||||
return _FakeExecResult(scalars=[])
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
|
@ -71,10 +79,64 @@ def _thread(*, pinned: int | None = None):
|
|||
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
|
||||
|
||||
|
||||
def _set_global_llm_configs(monkeypatch, config, configs: list[dict]):
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
connections = []
|
||||
models = []
|
||||
for cfg in configs:
|
||||
config_id = int(cfg["id"])
|
||||
connection_id = config_id - 100_000
|
||||
provider = cfg.get("provider") or cfg.get("litellm_provider")
|
||||
model_name = cfg["model_name"]
|
||||
if "supports_image_input" not in cfg:
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
connections.append(
|
||||
{
|
||||
"id": connection_id,
|
||||
"provider": provider,
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
model = {
|
||||
"id": config_id,
|
||||
"connection_id": connection_id,
|
||||
"model_id": model_name,
|
||||
"display_name": cfg.get("name") or model_name,
|
||||
"supports_chat": cfg.get("supports_chat", True),
|
||||
"supports_tools": cfg.get("supports_tools", True),
|
||||
"supports_image_generation": cfg.get("supports_image_generation", False),
|
||||
"capabilities_override": cfg.get("capabilities_override") or {},
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"catalog": {
|
||||
"auto_pin_tier": cfg.get("auto_pin_tier"),
|
||||
"quality_score": cfg.get("quality_score"),
|
||||
},
|
||||
"supports_image_input": cfg["supports_image_input"],
|
||||
}
|
||||
models.append(model)
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs)
|
||||
monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections)
|
||||
monkeypatch.setattr(config, "GLOBAL_MODELS", models)
|
||||
|
||||
|
||||
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": f"vision-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
|
|
@ -87,7 +149,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
|||
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": f"text-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
|
|
@ -108,11 +170,7 @@ async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
@ -140,11 +198,7 @@ async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
@ -172,9 +226,9 @@ async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-2))
|
||||
monkeypatch.setattr(
|
||||
_set_global_llm_configs(
|
||||
monkeypatch,
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -203,10 +257,8 @@ async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _text_only_cfg(-2)],
|
||||
_set_global_llm_configs(
|
||||
monkeypatch, config, [_text_only_cfg(-1), _text_only_cfg(-2)]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
|
|
@ -231,11 +283,7 @@ async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
|
|||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1)],
|
||||
)
|
||||
_set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1)])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
@ -261,7 +309,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
|
|||
session = _FakeSession(_thread())
|
||||
cfg_unannotated_vision = {
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-4o", # known vision model in LiteLLM map
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -269,7 +317,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
|
|||
"quality_score": 80,
|
||||
# NOTE: no supports_image_input key
|
||||
}
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
|
||||
_set_global_llm_configs(monkeypatch, config, [cfg_unannotated_vision])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
|
||||
_premium_allowed,
|
||||
|
|
|
|||
|
|
@ -1,19 +1,4 @@
|
|||
"""Defense-in-depth: image-gen call sites must not let an empty
|
||||
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
|
||||
|
||||
The bug repro: an OpenRouter image-gen config ships
|
||||
``api_base=""``. The pre-fix call site in
|
||||
``image_generation_routes._execute_image_generation`` did
|
||||
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
|
||||
silently dropped the empty string. LiteLLM then fell back to
|
||||
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
|
||||
and OpenRouter's ``image_generation/transformation`` appended
|
||||
``/chat/completions`` to it → 404 ``Resource not found``.
|
||||
|
||||
This test pins the post-fix behaviour: with an empty ``api_base`` in
|
||||
the config, the call site MUST set ``api_base`` to OpenRouter's public
|
||||
URL instead of leaving it unset.
|
||||
"""
|
||||
"""Image-gen call sites must pass each config's explicit ``api_base``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -26,22 +11,23 @@ pytestmark = pytest.mark.unit
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
||||
"""The global-config branch (``config_id < 0``) of
|
||||
``_execute_image_generation`` must apply the resolver and pin
|
||||
``api_base`` to OpenRouter when the config ships an empty string.
|
||||
"""
|
||||
async def test_global_openrouter_image_gen_sets_explicit_api_base():
|
||||
"""The global-config branch forwards the explicit OpenRouter base."""
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
cfg = {
|
||||
global_model = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "openai/gpt-image-1",
|
||||
"supports_image_generation": True,
|
||||
"capabilities_override": {},
|
||||
}
|
||||
global_connection = {
|
||||
"id": -101,
|
||||
"provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "", # the original bug shape
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"extra": {},
|
||||
}
|
||||
|
||||
captured: dict = {}
|
||||
|
|
@ -51,7 +37,7 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
|||
return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})
|
||||
|
||||
image_gen = MagicMock()
|
||||
image_gen.image_generation_config_id = cfg["id"]
|
||||
image_gen.image_gen_model_id = global_model["id"]
|
||||
image_gen.prompt = "test"
|
||||
image_gen.n = 1
|
||||
image_gen.quality = None
|
||||
|
|
@ -61,14 +47,19 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
|||
image_gen.model = None
|
||||
|
||||
search_space = MagicMock()
|
||||
search_space.image_generation_config_id = cfg["id"]
|
||||
search_space.image_gen_model_id = global_model["id"]
|
||||
session = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
"_get_global_image_gen_config",
|
||||
return_value=cfg,
|
||||
"_get_global_model",
|
||||
return_value=global_model,
|
||||
),
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
"_get_global_connection",
|
||||
return_value=global_connection,
|
||||
),
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
|
|
@ -80,30 +71,31 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
|||
session=session, image_gen=image_gen, search_space=search_space
|
||||
)
|
||||
|
||||
# The whole point of the fix: even with empty ``api_base`` in the
|
||||
# config, we forward OpenRouter's public URL so the call doesn't
|
||||
# inherit an Azure endpoint.
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
||||
"""Same defense at the agent tool entry point — both surfaces share
|
||||
async def test_generate_image_tool_global_sets_explicit_api_base():
|
||||
"""Same explicit-base behavior at the agent tool entry point — both surfaces share
|
||||
the same OpenRouter config payloads."""
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import (
|
||||
generate_image as gi_module,
|
||||
)
|
||||
|
||||
cfg = {
|
||||
global_model = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"connection_id": -101,
|
||||
"model_id": "openai/gpt-image-1",
|
||||
"supports_image_generation": True,
|
||||
"capabilities_override": {},
|
||||
}
|
||||
global_connection = {
|
||||
"id": -101,
|
||||
"provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"extra": {},
|
||||
}
|
||||
|
||||
captured: dict = {}
|
||||
|
|
@ -119,7 +111,7 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
|||
|
||||
search_space = MagicMock()
|
||||
search_space.id = 1
|
||||
search_space.image_generation_config_id = cfg["id"]
|
||||
search_space.image_gen_model_id = global_model["id"]
|
||||
|
||||
session_cm = AsyncMock()
|
||||
session = AsyncMock()
|
||||
|
|
@ -142,7 +134,10 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
|||
|
||||
with (
|
||||
patch.object(gi_module, "shielded_async_session", return_value=session_cm),
|
||||
patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
|
||||
patch.object(gi_module, "_get_global_model", return_value=global_model),
|
||||
patch.object(
|
||||
gi_module, "_get_global_connection", return_value=global_connection
|
||||
),
|
||||
patch.object(
|
||||
gi_module, "aimage_generation", side_effect=fake_aimage_generation
|
||||
),
|
||||
|
|
@ -171,20 +166,16 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
|||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
|
||||
"""The Auto-mode router pool must also resolve ``api_base`` when an
|
||||
OpenRouter config ships an empty string. The deployment dict is fed
|
||||
straight to ``litellm.Router``, so a missing ``api_base`` would
|
||||
leak the same way as the direct call sites.
|
||||
"""
|
||||
def test_image_gen_router_deployment_sets_explicit_api_base():
|
||||
"""The Auto-mode router pool carries explicit api_base into deployments."""
|
||||
from app.services.image_gen_router_service import ImageGenRouterService
|
||||
|
||||
deployment = ImageGenRouterService._config_to_deployment(
|
||||
{
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
)
|
||||
assert deployment is not None
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ def _fake_yaml_config(
|
|||
return {
|
||||
"id": id,
|
||||
"name": f"yaml-{id}",
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": model_name,
|
||||
"api_key": "sk-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"billing_tier": billing_tier,
|
||||
"rpm": 100,
|
||||
"tpm": 100_000,
|
||||
|
|
@ -54,10 +54,10 @@ def _fake_openrouter_config(
|
|||
return {
|
||||
"id": id,
|
||||
"name": f"or-{id}",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": model_name,
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": billing_tier,
|
||||
"rpm": 20 if billing_tier == "free" else 200,
|
||||
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
|
||||
|
|
@ -217,10 +217,64 @@ def test_auto_model_pin_candidates_include_dynamic_openrouter():
|
|||
model_name="meta-llama/llama-3.3-70b:free",
|
||||
billing_tier="free",
|
||||
)
|
||||
original = config.GLOBAL_LLM_CONFIGS
|
||||
global_connections = [
|
||||
{
|
||||
"id": -110_001,
|
||||
"provider": "openrouter",
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
},
|
||||
{
|
||||
"id": -110_002,
|
||||
"provider": "openrouter",
|
||||
"scope": "GLOBAL",
|
||||
"enabled": True,
|
||||
},
|
||||
]
|
||||
global_models = [
|
||||
{
|
||||
"id": or_premium["id"],
|
||||
"connection_id": -110_001,
|
||||
"model_id": or_premium["model_name"],
|
||||
"display_name": or_premium["name"],
|
||||
"supports_chat": True,
|
||||
"supports_image_input": True,
|
||||
"supports_tools": True,
|
||||
"supports_image_generation": False,
|
||||
"capabilities_override": {},
|
||||
"billing_tier": or_premium["billing_tier"],
|
||||
"catalog": {
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 50,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": or_free["id"],
|
||||
"connection_id": -110_002,
|
||||
"model_id": or_free["model_name"],
|
||||
"display_name": or_free["name"],
|
||||
"supports_chat": True,
|
||||
"supports_image_input": True,
|
||||
"supports_tools": True,
|
||||
"supports_image_generation": False,
|
||||
"capabilities_override": {},
|
||||
"billing_tier": or_free["billing_tier"],
|
||||
"catalog": {
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 50,
|
||||
},
|
||||
},
|
||||
]
|
||||
original_configs = config.GLOBAL_LLM_CONFIGS
|
||||
original_connections = config.GLOBAL_CONNECTIONS
|
||||
original_models = config.GLOBAL_MODELS
|
||||
try:
|
||||
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
|
||||
config.GLOBAL_CONNECTIONS = global_connections
|
||||
config.GLOBAL_MODELS = global_models
|
||||
candidate_ids = {c["id"] for c in _global_candidates()}
|
||||
assert candidate_ids == {-10_001, -10_002}
|
||||
finally:
|
||||
config.GLOBAL_LLM_CONFIGS = original
|
||||
config.GLOBAL_LLM_CONFIGS = original_configs
|
||||
config.GLOBAL_CONNECTIONS = original_connections
|
||||
config.GLOBAL_MODELS = original_models
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
from app.services.global_model_catalog import materialize_global_model_catalog
|
||||
from app.services.model_resolver import ensure_v1, to_litellm
|
||||
|
||||
|
||||
def test_openai_compatible_resolver_uses_explicit_api_base() -> None:
|
||||
model, kwargs = to_litellm(
|
||||
{
|
||||
"protocol": "OPENAI_COMPATIBLE",
|
||||
"provider": "openai",
|
||||
"base_url": "http://host.docker.internal:1234/v1",
|
||||
"api_key": "local-key",
|
||||
"extra": {},
|
||||
},
|
||||
"qwen/qwen3",
|
||||
)
|
||||
|
||||
assert model == "openai/qwen/qwen3"
|
||||
assert kwargs["api_base"] == "http://host.docker.internal:1234/v1"
|
||||
assert kwargs["api_key"] == "local-key"
|
||||
assert ensure_v1("http://example.com/v1") == "http://example.com/v1"
|
||||
|
||||
|
||||
def test_ollama_resolver_uses_native_api_base() -> None:
|
||||
model, kwargs = to_litellm(
|
||||
{
|
||||
"protocol": "OLLAMA",
|
||||
"provider": "ollama_chat",
|
||||
"base_url": "http://host.docker.internal:11434",
|
||||
"api_key": None,
|
||||
"extra": {},
|
||||
},
|
||||
"llama3.2",
|
||||
)
|
||||
|
||||
assert model == "ollama_chat/llama3.2"
|
||||
assert kwargs["api_base"] == "http://host.docker.internal:11434"
|
||||
|
||||
|
||||
def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None:
|
||||
connections, models = materialize_global_model_catalog(
|
||||
chat_configs=[
|
||||
{
|
||||
"id": -101,
|
||||
"name": "OpenRouter Free",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "meta-llama/llama-3.1-8b-instruct:free",
|
||||
"api_key": "sk-global-secret",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "free",
|
||||
"anonymous_enabled": True,
|
||||
"seo_enabled": True,
|
||||
"rpm": 10,
|
||||
"tpm": 1000,
|
||||
},
|
||||
{
|
||||
"id": -102,
|
||||
"name": "OpenRouter Premium",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "anthropic/claude-sonnet-4",
|
||||
"api_key": "sk-global-secret",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
],
|
||||
image_configs=[],
|
||||
)
|
||||
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["api_key"] == "sk-global-secret"
|
||||
assert {model["billing_tier"] for model in models} == {"free", "premium"}
|
||||
assert models[0]["catalog"]["anonymous_enabled"] is True
|
||||
assert models[0]["catalog"]["rpm"] == 10
|
||||
|
||||
public_connections = [
|
||||
{key: value for key, value in connection.items() if key != "api_key"}
|
||||
for connection in connections
|
||||
]
|
||||
assert "sk-" not in repr(public_connections)
|
||||
|
|
@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_image_gen_configs / _generate_vision_llm_configs
|
||||
# _generate_image_gen_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -263,18 +263,15 @@ def test_generate_image_gen_configs_filters_by_image_output():
|
|||
# Each config must carry ``billing_tier`` for routing in image_generation_routes.
|
||||
for c in cfgs:
|
||||
assert c["billing_tier"] in {"free", "premium"}
|
||||
assert c["provider"] == "OPENROUTER"
|
||||
assert c["provider"] == "openrouter"
|
||||
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
# Defense-in-depth: emit the OpenRouter base URL at source so a
|
||||
# downstream call site that forgets ``resolve_api_base`` still
|
||||
# doesn't 404 against an inherited Azure endpoint.
|
||||
# Emit the OpenRouter base URL at source so every call path passes an
|
||||
# explicit api_base and cannot inherit a process-global endpoint.
|
||||
assert c["api_base"] == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_generate_image_gen_configs_assigns_image_id_offset():
|
||||
"""Image configs use a different id_offset (-20000) so their negative
|
||||
IDs don't collide with chat configs (-10000) or vision configs (-30000).
|
||||
"""
|
||||
"""Image configs use their own id_offset (-20000)."""
|
||||
from app.services.openrouter_integration_service import (
|
||||
_generate_image_gen_configs,
|
||||
)
|
||||
|
|
@ -291,90 +288,3 @@ def test_generate_image_gen_configs_assigns_image_id_offset():
|
|||
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
|
||||
assert all(c["id"] < -20_000 + 1 for c in cfgs)
|
||||
assert all(c["id"] > -29_000_000 for c in cfgs)
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
|
||||
"""Vision LLMs must accept image input AND emit text — pure image-gen
|
||||
(no text out) and text-only (no image in) models are excluded.
|
||||
"""
|
||||
from app.services.openrouter_integration_service import (
|
||||
_generate_vision_llm_configs,
|
||||
)
|
||||
|
||||
raw = [
|
||||
# GPT-4o: vision LLM (image in, text out) — must emit.
|
||||
{
|
||||
"id": "openai/gpt-4o",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"context_length": 128_000,
|
||||
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
|
||||
},
|
||||
# Pure image generator — image *output*, no text out. Must NOT emit.
|
||||
{
|
||||
"id": "openai/gpt-image-1",
|
||||
"architecture": {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["image"],
|
||||
},
|
||||
"context_length": 4_000,
|
||||
"pricing": {"prompt": "0", "completion": "0"},
|
||||
},
|
||||
# Pure text model (no image in). Must NOT emit.
|
||||
{
|
||||
"id": "anthropic/claude-3-haiku",
|
||||
"architecture": {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"context_length": 200_000,
|
||||
"pricing": {"prompt": "0.000001", "completion": "0.000005"},
|
||||
},
|
||||
]
|
||||
|
||||
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
|
||||
names = {c["model_name"] for c in cfgs}
|
||||
assert names == {"openai/gpt-4o"}
|
||||
|
||||
cfg = cfgs[0]
|
||||
assert cfg["billing_tier"] == "premium"
|
||||
# Pricing carried inline so pricing_registration can register vision
|
||||
# under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
|
||||
# is cleared.
|
||||
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
|
||||
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
|
||||
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
# Defense-in-depth: emit the OpenRouter base URL at source so a
|
||||
# downstream call site that forgets ``resolve_api_base`` still
|
||||
# doesn't inherit an Azure endpoint.
|
||||
assert cfg["api_base"] == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_drops_chat_only_filters():
|
||||
"""A small-context vision model that doesn't advertise tool calling is
|
||||
still a valid vision LLM for "describe this image" prompts. The chat
|
||||
filters (``supports_tool_calling``, ``has_sufficient_context``) must
|
||||
NOT be applied to vision emission.
|
||||
"""
|
||||
from app.services.openrouter_integration_service import (
|
||||
_generate_vision_llm_configs,
|
||||
)
|
||||
|
||||
raw = [
|
||||
{
|
||||
"id": "tiny/vision-mini",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"supported_parameters": [], # no tools
|
||||
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
|
||||
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
|
||||
}
|
||||
]
|
||||
|
||||
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
|
||||
assert len(cfgs) == 1
|
||||
assert cfgs[0]["model_name"] == "tiny/vision-mini"
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def _or_cfg(
|
|||
) -> dict:
|
||||
return {
|
||||
"id": cid,
|
||||
"provider": "OPENROUTER",
|
||||
"provider": "openrouter",
|
||||
"model_name": model_name,
|
||||
"billing_tier": tier,
|
||||
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||
|
|
@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch):
|
|||
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
|
||||
yaml_cfg = {
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
|
|
@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
|
|||
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
|
||||
yaml_cfg: dict[str, Any] = {
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"billing_tier": "premium",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ def test_openrouter_models_register_under_aliases(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "anthropic/claude-3-5-sonnet",
|
||||
}
|
||||
],
|
||||
|
|
@ -228,7 +228,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5.4",
|
||||
"litellm_params": {
|
||||
"base_model": "gpt-5.4",
|
||||
|
|
@ -243,7 +243,6 @@ def test_yaml_override_registers_under_alias_set(monkeypatch):
|
|||
|
||||
keys = spy.all_keys
|
||||
assert "gpt-5.4" in keys
|
||||
assert "azure_openai/gpt-5.4" in keys
|
||||
assert "azure/gpt-5.4" in keys
|
||||
|
||||
payload = spy.calls[0]
|
||||
|
|
@ -271,7 +270,7 @@ def test_no_override_means_no_registration(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"litellm_params": {"base_model": "gpt-4o"},
|
||||
}
|
||||
|
|
@ -302,7 +301,7 @@ def test_openrouter_skipped_when_pricing_missing(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "anthropic/claude-3-5-sonnet",
|
||||
}
|
||||
],
|
||||
|
|
@ -349,12 +348,12 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog):
|
|||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "anthropic/claude-3-5-sonnet",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "custom-deployment",
|
||||
"litellm_params": {
|
||||
"base_model": "custom-deployment",
|
||||
|
|
@ -369,79 +368,3 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog):
|
|||
|
||||
# The good config still registered.
|
||||
assert any("custom-deployment" in payload for payload in successful_calls)
|
||||
|
||||
|
||||
def test_vision_configs_registered_with_chat_shape(monkeypatch):
|
||||
"""``register_pricing_from_global_configs`` walks
|
||||
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
|
||||
calls (during indexing) bill correctly. Vision configs use the same
|
||||
chat-shape token prices, but image-gen pricing is intentionally NOT
|
||||
registered here (handled via ``response_cost`` in LiteLLM).
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(
|
||||
monkeypatch,
|
||||
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
|
||||
)
|
||||
|
||||
# No chat configs — only vision. Proves the vision walk is a separate
|
||||
# iteration, not piggy-backed on the chat list.
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"billing_tier": "premium",
|
||||
"input_cost_per_token": 5e-6,
|
||||
"output_cost_per_token": 15e-6,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/openai/gpt-4o" in spy.all_keys
|
||||
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
|
||||
assert payload_value["mode"] == "chat"
|
||||
assert payload_value["litellm_provider"] == "openrouter"
|
||||
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
|
||||
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
|
||||
|
||||
|
||||
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
|
||||
"""If the OpenRouter pricing cache misses a vision model (different
|
||||
catalogue surface), the vision walk falls back to inline
|
||||
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(monkeypatch, {})
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_VISION_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash",
|
||||
"billing_tier": "premium",
|
||||
"input_cost_per_token": 1e-6,
|
||||
"output_cost_per_token": 4e-6,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
|
||||
|
|
|
|||
|
|
@ -1,107 +0,0 @@
|
|||
"""Unit tests for the shared ``api_base`` resolver.
|
||||
|
||||
The cascade exists so vision and image-gen call sites can't silently
|
||||
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
|
||||
when an OpenRouter / Groq / etc. config ships an empty string. See
|
||||
``provider_api_base`` module docstring for the original repro
|
||||
(OpenRouter image-gen 404-ing against an Azure endpoint).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_api_base import (
|
||||
PROVIDER_DEFAULT_API_BASE,
|
||||
PROVIDER_KEY_DEFAULT_API_BASE,
|
||||
resolve_api_base,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_config_value_wins_over_defaults():
|
||||
"""A non-empty config value is always returned verbatim, even when the
|
||||
provider has a default — the operator gets the last word."""
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base="https://my-openrouter-mirror.example.com/v1",
|
||||
)
|
||||
assert result == "https://my-openrouter-mirror.example.com/v1"
|
||||
|
||||
|
||||
def test_provider_key_default_when_config_missing():
|
||||
"""``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
|
||||
base URL — the provider-key map must take precedence over the prefix
|
||||
map so DeepSeek requests don't go to OpenAI."""
|
||||
result = resolve_api_base(
|
||||
provider="DEEPSEEK",
|
||||
provider_prefix="openai",
|
||||
config_api_base=None,
|
||||
)
|
||||
assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
|
||||
|
||||
|
||||
def test_provider_prefix_default_when_no_key_default():
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base=None,
|
||||
)
|
||||
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
|
||||
|
||||
|
||||
def test_unknown_provider_returns_none():
|
||||
"""When neither map matches we return ``None`` so the caller can let
|
||||
LiteLLM apply its own provider-integration default (Azure deployment
|
||||
URL, custom-provider URL, etc.)."""
|
||||
result = resolve_api_base(
|
||||
provider="SOMETHING_NEW",
|
||||
provider_prefix="something_new",
|
||||
config_api_base=None,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_empty_string_config_treated_as_missing():
|
||||
"""The original bug: OpenRouter dynamic configs ship ``api_base=""``
|
||||
and downstream call sites use ``if cfg.get("api_base"):`` — empty
|
||||
strings are falsy in Python but the cascade has to step in anyway."""
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base="",
|
||||
)
|
||||
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
|
||||
|
||||
|
||||
def test_whitespace_only_config_treated_as_missing():
|
||||
"""A config value of ``" "`` is a configuration mistake — treat it
|
||||
as missing instead of forwarding whitespace to LiteLLM (which would
|
||||
almost certainly 404)."""
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base=" ",
|
||||
)
|
||||
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
|
||||
|
||||
|
||||
def test_provider_case_insensitive():
|
||||
"""Some call sites pass the provider lowercase (DB enum value), others
|
||||
uppercase (YAML key). Both must resolve."""
|
||||
upper = resolve_api_base(
|
||||
provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
|
||||
)
|
||||
lower = resolve_api_base(
|
||||
provider="deepseek", provider_prefix="openai", config_api_base=None
|
||||
)
|
||||
assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
|
||||
|
||||
|
||||
def test_all_inputs_none_returns_none():
|
||||
assert (
|
||||
resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
|
||||
is None
|
||||
)
|
||||
|
|
@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit
|
|||
def test_or_modalities_with_image_returns_true():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENROUTER",
|
||||
provider="openrouter",
|
||||
model_name="openai/gpt-4o",
|
||||
openrouter_input_modalities=["text", "image"],
|
||||
)
|
||||
|
|
@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true():
|
|||
def test_or_modalities_text_only_returns_false():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENROUTER",
|
||||
provider="openrouter",
|
||||
model_name="deepseek/deepseek-v3.2-exp",
|
||||
openrouter_input_modalities=["text"],
|
||||
)
|
||||
|
|
@ -57,7 +57,7 @@ def test_or_modalities_empty_list_returns_false():
|
|||
to LiteLLM."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENROUTER",
|
||||
provider="openrouter",
|
||||
model_name="weird/empty-modalities",
|
||||
openrouter_input_modalities=[],
|
||||
)
|
||||
|
|
@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm():
|
|||
to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
openrouter_input_modalities=None,
|
||||
)
|
||||
|
|
@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm():
|
|||
def test_litellm_known_vision_model_returns_true():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is True
|
||||
|
|
@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name():
|
|||
doesn't know) would shadow the real capability."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="AZURE_OPENAI",
|
||||
provider="azure",
|
||||
model_name="my-azure-deployment-id",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
|
|
@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows():
|
|||
"""Default-allow on unknown — the safety net is the actual block."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="CUSTOM",
|
||||
provider="custom",
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
|
|
@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false():
|
|||
# Sanity: confirm the helper's negative path. We use a small model
|
||||
# known not to support vision per the map.
|
||||
result = derive_supports_image_input(
|
||||
provider="DEEPSEEK",
|
||||
provider="openai",
|
||||
model_name="deepseek-chat",
|
||||
)
|
||||
# We accept either False (LiteLLM said explicit no) or True
|
||||
|
|
@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false():
|
|||
def test_is_known_text_only_returns_false_for_vision_model():
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
|
|
@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model():
|
|||
fixing."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="CUSTOM",
|
||||
provider="custom",
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
|
|
@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
|
|
@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="any-model",
|
||||
)
|
||||
is True
|
||||
|
|
@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="any-model",
|
||||
)
|
||||
is False
|
||||
|
|
@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="any-model",
|
||||
)
|
||||
is False
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Unit tests for the Auto (Fastest) quality scoring module."""
|
||||
"""Unit tests for the Auto quality scoring module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -228,7 +228,7 @@ def test_static_score_or_recent_release_beats_year_old_same_provider():
|
|||
|
||||
def test_static_score_yaml_includes_operator_bonus():
|
||||
cfg = {
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"litellm_params": {"base_model": "azure/gpt-5"},
|
||||
}
|
||||
|
|
@ -238,7 +238,7 @@ def test_static_score_yaml_includes_operator_bonus():
|
|||
|
||||
def test_static_score_yaml_unknown_provider_still_carries_bonus():
|
||||
cfg = {
|
||||
"provider": "SOME_NEW_PROVIDER",
|
||||
"litellm_provider": "some_new_provider",
|
||||
"model_name": "weird-model",
|
||||
}
|
||||
score = static_score_yaml(cfg)
|
||||
|
|
@ -247,7 +247,7 @@ def test_static_score_yaml_unknown_provider_still_carries_bonus():
|
|||
|
||||
def test_static_score_yaml_clamped_0_to_100():
|
||||
cfg = {
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"litellm_params": {"base_model": "azure/gpt-5"},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,6 +131,10 @@ def test_serialized_calls_includes_cost_micros():
|
|||
assert serialized == [
|
||||
{
|
||||
"model": "m",
|
||||
"model_ref": None,
|
||||
"model_id": None,
|
||||
"display_name": None,
|
||||
"provider": None,
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
|
|
|
|||
|
|
@ -1,89 +0,0 @@
|
|||
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
|
||||
defaults from ``litellm.api_base`` either.
|
||||
|
||||
Vision shares the same shape as image-gen — global YAML / OpenRouter
|
||||
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
|
||||
call sites would silently drop the empty string and inherit
|
||||
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
|
||||
construction so we test the kwargs we hand to it instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vision_llm_global_openrouter_sets_api_base():
|
||||
"""Global negative-ID branch: an OpenRouter vision config with
|
||||
``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
|
||||
``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
|
||||
never silently absent."""
|
||||
from app.services import llm_service
|
||||
|
||||
cfg = {
|
||||
"id": -30_001,
|
||||
"name": "GPT-4o Vision (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
search_space = MagicMock()
|
||||
search_space.id = 1
|
||||
search_space.user_id = "user-x"
|
||||
search_space.vision_llm_config_id = cfg["id"]
|
||||
|
||||
session = AsyncMock()
|
||||
scalars = MagicMock()
|
||||
scalars.first.return_value = search_space
|
||||
result = MagicMock()
|
||||
result.scalars.return_value = scalars
|
||||
session.execute.return_value = result
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
class FakeSanitized:
|
||||
def __init__(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.vision_llm_router_service.get_global_vision_llm_config",
|
||||
return_value=cfg,
|
||||
),
|
||||
patch(
|
||||
"app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM",
|
||||
new=FakeSanitized,
|
||||
),
|
||||
):
|
||||
await llm_service.get_vision_llm(session=session, search_space_id=1)
|
||||
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-4o"
|
||||
|
||||
|
||||
def test_vision_router_deployment_sets_api_base_when_config_empty():
|
||||
"""Auto-mode vision router: deployments are fed to ``litellm.Router``,
|
||||
so the resolver has to apply at deployment construction time too."""
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
deployment = VisionLLMRouterService._config_to_deployment(
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"provider": "OPENROUTER",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
}
|
||||
)
|
||||
assert deployment is not None
|
||||
assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
|
||||
assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
|
||||
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _exception_named(name: str, message: str) -> Exception:
|
||||
return type(name, (Exception,), {})(message)
|
||||
|
||||
|
||||
def test_adapter_classifies_authentication_error_by_class_name() -> None:
|
||||
exc = _exception_named("AuthenticationError", "provider rejected credentials")
|
||||
|
||||
adapted = adapt_llm_exception(exc)
|
||||
|
||||
assert adapted.category is LLMErrorCategory.AUTH_FAILED
|
||||
assert adapted.retryable is False
|
||||
assert adapted.user_message == "LLM authentication failed. Check your API key."
|
||||
|
||||
|
||||
def test_adapter_classifies_embedded_provider_401_payload() -> None:
|
||||
exc = RuntimeError(
|
||||
'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}'
|
||||
)
|
||||
|
||||
adapted = adapt_llm_exception(exc)
|
||||
|
||||
assert adapted.category is LLMErrorCategory.AUTH_FAILED
|
||||
assert adapted.provider_status_code == 401
|
||||
|
||||
|
||||
def test_adapter_preserves_rate_limit_classification() -> None:
|
||||
exc = RuntimeError('{"error":{"message":"Slow down","code":429}}')
|
||||
|
||||
adapted = adapt_llm_exception(exc)
|
||||
|
||||
assert adapted.category is LLMErrorCategory.RATE_LIMITED
|
||||
assert adapted.retryable is True
|
||||
|
||||
|
||||
def test_stream_classifier_maps_model_auth_to_stable_code() -> None:
|
||||
exc = RuntimeError(
|
||||
'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}'
|
||||
)
|
||||
|
||||
kind, code, severity, expected, message, extra = classify_stream_exception(
|
||||
exc,
|
||||
flow_label="chat",
|
||||
)
|
||||
|
||||
assert kind == "model_auth_failed"
|
||||
assert code == "MODEL_AUTH_FAILED"
|
||||
assert severity == "warn"
|
||||
assert expected is True
|
||||
assert "API key" in message
|
||||
assert extra == {
|
||||
"provider_error_category": "auth_failed",
|
||||
"provider_status_code": 401,
|
||||
}
|
||||
|
||||
|
||||
def test_stream_classifier_keeps_unknown_errors_generic() -> None:
|
||||
exc = RuntimeError("database exploded")
|
||||
|
||||
kind, code, severity, expected, message, extra = classify_stream_exception(
|
||||
exc,
|
||||
flow_label="chat",
|
||||
)
|
||||
|
||||
assert kind == "server_error"
|
||||
assert code == "SERVER_ERROR"
|
||||
assert severity == "error"
|
||||
assert expected is False
|
||||
assert message == "Error during chat: database exploded"
|
||||
assert extra is None
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
"""Unit tests for provider-safe LLM history normalization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.llm_history_normalizer import (
|
||||
assistant_content_to_llm_text,
|
||||
user_content_to_llm_content,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None:
|
||||
content = [
|
||||
{"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]},
|
||||
{"type": "text", "text": "visible answer"},
|
||||
]
|
||||
|
||||
assert assistant_content_to_llm_text(content) == "visible answer"
|
||||
|
||||
|
||||
def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None:
|
||||
content = [
|
||||
{"type": "thinking", "thinking": "private reasoning"},
|
||||
{"type": "text", "text": "final answer"},
|
||||
]
|
||||
|
||||
assert assistant_content_to_llm_text(content) == "final answer"
|
||||
|
||||
|
||||
def test_unknown_assistant_blocks_are_dropped() -> None:
|
||||
content = [
|
||||
{"type": "redacted_thinking", "data": "hidden"},
|
||||
{"type": "tool_use", "name": "search"},
|
||||
{"type": "text", "text": "kept"},
|
||||
]
|
||||
|
||||
assert assistant_content_to_llm_text(content) == "kept"
|
||||
|
||||
|
||||
def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None:
|
||||
content = [
|
||||
{"type": "text", "text": "look"},
|
||||
{"type": "image", "image": "data:image/png;base64,abc"},
|
||||
]
|
||||
|
||||
assert user_content_to_llm_content(content, allow_images=True) == [
|
||||
{"type": "text", "text": "look"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
]
|
||||
|
||||
|
||||
def test_user_images_can_be_dropped_for_text_only_history() -> None:
|
||||
content = [
|
||||
{"type": "text", "text": "look"},
|
||||
{"type": "image", "image": "data:image/png;base64,abc"},
|
||||
]
|
||||
|
||||
assert user_content_to_llm_content(content, allow_images=False) == "look"
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Unit tests for final assistant message part normalization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from app.tasks.chat.message_parts_normalizer import (
|
||||
final_assistant_parts_from_messages,
|
||||
merge_streamed_and_final_parts,
|
||||
normalize_ai_message_to_parts,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_string_ai_message_content_becomes_text_part() -> None:
|
||||
assert normalize_ai_message_to_parts(AIMessage(content="hello")) == [
|
||||
{"type": "text", "text": "hello"}
|
||||
]
|
||||
|
||||
|
||||
def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None:
|
||||
message = AIMessage(
|
||||
content=[
|
||||
{"type": "thinking", "thinking": "hidden reasoning"},
|
||||
{"type": "text", "text": "Yo bro! What's up?"},
|
||||
],
|
||||
additional_kwargs={"reasoning_content": "hidden reasoning"},
|
||||
)
|
||||
|
||||
assert normalize_ai_message_to_parts(message) == [
|
||||
{"type": "text", "text": "Yo bro! What's up?"}
|
||||
]
|
||||
|
||||
|
||||
def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None:
|
||||
messages = [
|
||||
HumanMessage(content="ask"),
|
||||
AIMessage(content="draft"),
|
||||
ToolMessage(content="tool output", tool_call_id="tc-1"),
|
||||
AIMessage(content=[{"type": "text", "text": "final answer"}]),
|
||||
ToolMessage(content="trailing tool noise", tool_call_id="tc-2"),
|
||||
]
|
||||
|
||||
assert final_assistant_parts_from_messages(messages) == [
|
||||
{"type": "text", "text": "final answer"}
|
||||
]
|
||||
|
||||
|
||||
def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None:
|
||||
streamed = [
|
||||
{
|
||||
"type": "data-thinking-steps",
|
||||
"data": [{"id": "thinking-1", "status": "completed"}],
|
||||
}
|
||||
]
|
||||
final = [{"type": "text", "text": "visible answer"}]
|
||||
|
||||
assert merge_streamed_and_final_parts(streamed, final) == [*streamed, *final]
|
||||
|
||||
|
||||
def test_merge_does_not_duplicate_when_stream_already_has_text() -> None:
|
||||
streamed = [{"type": "text", "text": "streamed answer"}]
|
||||
final = [{"type": "text", "text": "final answer"}]
|
||||
|
||||
assert merge_streamed_and_final_parts(streamed, final) == streamed
|
||||
|
|
@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o():
|
|||
it text-only."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="AZURE_OPENAI",
|
||||
provider="azure",
|
||||
model_name="my-azure-deployment",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
|
|
@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model():
|
|||
LiteLLM doesn't know about must flow through to the provider."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="CUSTOM",
|
||||
provider="custom",
|
||||
custom_provider="brand_new_proxy",
|
||||
model_name="brand-new-model-x9",
|
||||
)
|
||||
|
|
@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
|
|
@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="text-only-stub",
|
||||
)
|
||||
is True
|
||||
|
|
@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="vision-stub",
|
||||
)
|
||||
is False
|
||||
|
|
@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
provider="openai",
|
||||
model_name="missing-key-stub",
|
||||
)
|
||||
is False
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue