Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-07 06:42:39 +02:00)

Merge upstream/dev into feature/multi-agent

Commit 5119915f4f: 278 changed files with 34669 additions and 8970 deletions
surfsense_backend/app/services/auto_model_pin_service.py (new file, 479 lines)

@@ -0,0 +1,479 @@
"""Resolve and persist Auto (Fastest) model pins per chat thread.
|
||||
|
||||
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global LLM config exactly once and
|
||||
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
|
||||
subsequent turns are stable.
|
||||
|
||||
Single-writer invariant: this module is the only writer of
|
||||
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
|
||||
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
|
||||
Therefore a non-NULL value unambiguously means "this thread has an
|
||||
Auto-resolved pin"; no separate source/policy column is needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import NewChatThread
|
||||
from app.services.quality_score import _QUALITY_TOP_K
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTO_FASTEST_ID = 0
|
||||
AUTO_FASTEST_MODE = "auto_fastest"
|
||||
_RUNTIME_COOLDOWN_SECONDS = 600
|
||||
_HEALTHY_TTL_SECONDS = 45
|
||||
|
||||
# In-memory runtime cooldown map for configs that recently hard-failed at
|
||||
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
|
||||
# the same unhealthy config from being reselected immediately during repair.
|
||||
_runtime_cooldown_until: dict[int, float] = {}
|
||||
_runtime_cooldown_lock = threading.Lock()
|
||||
|
||||
# Short-TTL "recently healthy" cache for configs that just passed a runtime
|
||||
# preflight ping. Lets back-to-back turns on the same model skip the probe
|
||||
# without eroding correctness — entries auto-expire and are wiped any time
|
||||
# the same config is cooled down or the OR catalogue is refreshed.
|
||||
_healthy_until: dict[int, float] = {}
|
||||
_healthy_lock = threading.Lock()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoPinResolution:
|
||||
resolved_llm_config_id: int
|
||||
resolved_tier: str
|
||||
from_existing_pin: bool
|
||||
|
||||
|
||||
def _is_usable_global_config(cfg: dict) -> bool:
|
||||
return bool(
|
||||
cfg.get("id") is not None
|
||||
and cfg.get("model_name")
|
||||
and cfg.get("provider")
|
||||
and cfg.get("api_key")
|
||||
)
|
||||
|
||||
|
||||
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
|
||||
now = time.time() if now_ts is None else now_ts
|
||||
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
|
||||
for cid in stale:
|
||||
_runtime_cooldown_until.pop(cid, None)
|
||||
|
||||
|
||||
def _is_runtime_cooled_down(config_id: int) -> bool:
|
||||
with _runtime_cooldown_lock:
|
||||
_prune_runtime_cooldowns()
|
||||
return config_id in _runtime_cooldown_until
|
||||
|
||||
|
||||
def mark_runtime_cooldown(
|
||||
config_id: int,
|
||||
*,
|
||||
reason: str = "rate_limited",
|
||||
cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
|
||||
) -> None:
|
||||
"""Temporarily suppress a config from Auto selection.
|
||||
|
||||
Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
|
||||
config that is currently unhealthy does not get immediately reused on the
|
||||
same thread during repair.
|
||||
"""
|
||||
if cooldown_seconds <= 0:
|
||||
cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS
|
||||
until = time.time() + int(cooldown_seconds)
|
||||
with _runtime_cooldown_lock:
|
||||
_runtime_cooldown_until[int(config_id)] = until
|
||||
_prune_runtime_cooldowns()
|
||||
# A cooled cfg can never be "recently healthy"; drop any stale credit so
|
||||
# the next turn that resolves to it (after cooldown) re-runs preflight.
|
||||
clear_healthy(int(config_id))
|
||||
logger.info(
|
||||
"auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
|
||||
config_id,
|
||||
reason,
|
||||
cooldown_seconds,
|
||||
)
|
||||
|
||||
|
||||
def clear_runtime_cooldown(config_id: int | None = None) -> None:
|
||||
"""Test/ops helper to clear runtime cooldown entries."""
|
||||
with _runtime_cooldown_lock:
|
||||
if config_id is None:
|
||||
_runtime_cooldown_until.clear()
|
||||
return
|
||||
_runtime_cooldown_until.pop(int(config_id), None)
|
||||
|
||||
|
||||
def _prune_healthy(now_ts: float | None = None) -> None:
|
||||
now = time.time() if now_ts is None else now_ts
|
||||
stale = [cid for cid, until in _healthy_until.items() if until <= now]
|
||||
for cid in stale:
|
||||
_healthy_until.pop(cid, None)
|
||||
|
||||
|
||||
def is_recently_healthy(config_id: int) -> bool:
|
||||
"""Return True if ``config_id`` passed preflight within the TTL window."""
|
||||
with _healthy_lock:
|
||||
_prune_healthy()
|
||||
return int(config_id) in _healthy_until
|
||||
|
||||
|
||||
def mark_healthy(
|
||||
config_id: int,
|
||||
*,
|
||||
ttl_seconds: int = _HEALTHY_TTL_SECONDS,
|
||||
) -> None:
|
||||
"""Record that ``config_id`` just passed a preflight probe.
|
||||
|
||||
Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The
|
||||
healthy state is intentionally process-local — it's a latency hint, not a
|
||||
correctness primitive — so multi-worker drift is acceptable.
|
||||
"""
|
||||
if ttl_seconds <= 0:
|
||||
ttl_seconds = _HEALTHY_TTL_SECONDS
|
||||
until = time.time() + int(ttl_seconds)
|
||||
with _healthy_lock:
|
||||
_healthy_until[int(config_id)] = until
|
||||
_prune_healthy()
|
||||
|
||||
|
||||
def clear_healthy(config_id: int | None = None) -> None:
|
||||
"""Drop one (or all) healthy-cache entries.
|
||||
|
||||
Called from runtime cooldown and OR catalogue refresh so a freshly cooled
|
||||
or replaced config never carries stale "healthy" credit.
|
||||
"""
|
||||
with _healthy_lock:
|
||||
if config_id is None:
|
||||
_healthy_until.clear()
|
||||
return
|
||||
_healthy_until.pop(int(config_id), None)
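
# Illustrative sketch (not part of this module): how the cooldown and
# healthy caches above are intended to interact in a runtime error handler
# and a preflight path. Every name other than mark_runtime_cooldown,
# mark_healthy and is_recently_healthy below is hypothetical.
#
#     try:
#         await run_pinned_model(cfg_id)          # hypothetical provider call
#         mark_healthy(cfg_id)                    # skip preflight for ~45s
#     except ProviderRateLimited:                 # e.g. OpenRouter 429
#         mark_runtime_cooldown(cfg_id, reason="rate_limited")
#         # repair will now skip cfg_id for _RUNTIME_COOLDOWN_SECONDS
#
#     if not is_recently_healthy(cfg_id):
#         await preflight_ping(cfg_id)            # hypothetical probe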


def _cfg_supports_image_input(cfg: dict) -> bool:
    """True if the global cfg can accept image inputs.

    Prefers the explicit ``supports_image_input`` flag (set by the YAML
    loader / OpenRouter integration). Falls back to a LiteLLM lookup so
    a YAML entry whose flag was somehow stripped doesn't get wrongly
    excluded. Default-allows on unknown — the streaming-task safety net
    is the actual block, not this filter.
    """
    if "supports_image_input" in cfg:
        return bool(cfg.get("supports_image_input"))
    # Lazy import: provider_capabilities -> llm_config -> services chain;
    # importing at module load would create an init-order cycle through
    # ``app.config``.
    from app.services.provider_capabilities import derive_supports_image_input

    cfg_litellm_params = cfg.get("litellm_params") or {}
    base_model = (
        cfg_litellm_params.get("base_model")
        if isinstance(cfg_litellm_params, dict)
        else None
    )
    return derive_supports_image_input(
        provider=cfg.get("provider"),
        model_name=cfg.get("model_name"),
        base_model=base_model,
        custom_provider=cfg.get("custom_provider"),
    )


def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
    """Return Auto-eligible global cfgs.

    Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
    below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
    can't be picked as the thread's pin. Also excludes configs currently
    in runtime cooldown (e.g. temporary 429 bursts).

    When ``requires_image_input`` is True (image turn), additionally
    filters out configs whose ``supports_image_input`` resolves to False
    so a text-only deployment can't be pinned for an image request.
    """
    candidates = [
        cfg
        for cfg in config.GLOBAL_LLM_CONFIGS
        if _is_usable_global_config(cfg)
        and not cfg.get("health_gated")
        and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
        and (not requires_image_input or _cfg_supports_image_input(cfg))
    ]
    return sorted(candidates, key=lambda c: int(c.get("id", 0)))


def _tier_of(cfg: dict) -> str:
    return str(cfg.get("billing_tier", "free")).lower()


def _is_preferred_premium_auto_config(cfg: dict) -> bool:
    """Return True for the operator-preferred premium Auto model."""
    return (
        _tier_of(cfg) == "premium"
        and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
        and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
    )


def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
    """Pick a config with quality-first ranking + deterministic spread.

    Tier policy is lock-first: prefer Tier A (operator-curated YAML)
    cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
    Tier A cfg is eligible after upstream filters. Within the locked
    pool, sort by ``quality_score`` and pick from the top-K via
    ``SHA256(thread_id)`` so different new threads spread across the
    best models without ever picking a low-ranked one.

    Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
    structured logging in the caller.
    """
    tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
    pool = tier_a if tier_a else eligible
    pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
    top_k = pool[:_QUALITY_TOP_K]
    digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
    idx = int.from_bytes(digest[:8], "big") % len(top_k)
    return top_k[idx], len(top_k)
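
# Worked example (illustrative; ids, scores, and a top-K of 3 are made up):
# given an eligible Tier A pool scored [cfg -3: 92, cfg -7: 90, cfg -9: 88]
# and (say) _QUALITY_TOP_K = 3, every thread hashes into the same top-3
# slice, but different threads land on different members:
#
#     digest = hashlib.sha256(f"auto_fastest:{thread_id}".encode()).digest()
#     idx = int.from_bytes(digest[:8], "big") % 3
#
# The mapping is stable per thread_id (repeated calls pick the same cfg) and
# never selects anything ranked below the top-K cutoff.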


def _to_uuid(user_id: str | UUID | None) -> UUID | None:
    if user_id is None:
        return None
    if isinstance(user_id, UUID):
        return user_id
    try:
        return UUID(str(user_id))
    except Exception:
        return None


async def _is_premium_eligible(
    session: AsyncSession, user_id: str | UUID | None
) -> bool:
    parsed = _to_uuid(user_id)
    if parsed is None:
        return False
    usage = await TokenQuotaService.premium_get_usage(session, parsed)
    return bool(usage.allowed)


async def resolve_or_get_pinned_llm_config_id(
    session: AsyncSession,
    *,
    thread_id: int,
    search_space_id: int,
    user_id: str | UUID | None,
    selected_llm_config_id: int,
    force_repin_free: bool = False,
    exclude_config_ids: set[int] | None = None,
    requires_image_input: bool = False,
) -> AutoPinResolution:
    """Resolve Auto (Fastest) to one concrete config id and persist the pin.

    For non-auto selections, this function clears any existing pin and returns
    the selected id as-is.

    When ``requires_image_input`` is True (the current turn carries an
    ``image_url`` block), the candidate pool is filtered to vision-capable
    cfgs and any existing pin that can't accept image input is treated as
    invalid (force re-pin). If no vision-capable cfg is available the
    function raises ``ValueError`` so the streaming task surfaces the same
    friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
    silently routing the image to a text-only deployment.
    """
    thread = (
        (
            await session.execute(
                select(NewChatThread)
                .where(NewChatThread.id == thread_id)
                .with_for_update(of=NewChatThread)
            )
        )
        .unique()
        .scalar_one_or_none()
    )
    if thread is None:
        raise ValueError(f"Thread {thread_id} not found")
    if thread.search_space_id != search_space_id:
        raise ValueError(
            f"Thread {thread_id} does not belong to search space {search_space_id}"
        )

    # Explicit model selected: clear any stale pin.
    if selected_llm_config_id != AUTO_FASTEST_ID:
        if thread.pinned_llm_config_id is not None:
            thread.pinned_llm_config_id = None
            await session.commit()
        return AutoPinResolution(
            resolved_llm_config_id=selected_llm_config_id,
            resolved_tier="explicit",
            from_existing_pin=False,
        )

    excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
    candidates = [
        c
        for c in _global_candidates(requires_image_input=requires_image_input)
        if int(c.get("id", 0)) not in excluded_ids
    ]
    if not candidates:
        if requires_image_input:
            # Distinguish the "no vision-capable cfg" case from generic
            # "no usable cfg" so the streaming task can map this to the
            # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
            raise ValueError(
                "No vision-capable global LLM configs are available for Auto mode"
            )
        raise ValueError("No usable global LLM configs are available for Auto mode")
    candidate_by_id = {int(c["id"]): c for c in candidates}

    # Reuse an existing valid pin without re-checking current quota (no silent
    # tier switch), unless the caller explicitly requests a forced repin to free
    # *or* the turn requires image input but the pin can't handle it.
    pinned_id = thread.pinned_llm_config_id
    if (
        not force_repin_free
        and pinned_id is not None
        and int(pinned_id) in candidate_by_id
    ):
        pinned_cfg = candidate_by_id[int(pinned_id)]
        logger.info(
            "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
            thread_id,
            search_space_id,
            pinned_id,
            _tier_of(pinned_cfg),
        )
        logger.info(
            "auto_pin_resolved thread_id=%s config_id=%s tier=%s "
            "auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
            thread_id,
            pinned_id,
            _tier_of(pinned_cfg),
            pinned_cfg.get("auto_pin_tier", "?"),
            int(pinned_cfg.get("quality_score") or 0),
        )
        return AutoPinResolution(
            resolved_llm_config_id=int(pinned_id),
            resolved_tier=_tier_of(pinned_cfg),
            from_existing_pin=True,
        )
    if pinned_id is not None:
        # If the pin is *only* invalid because it can't handle the image
        # turn (it's still a healthy, usable config in the broader pool),
        # log that explicitly so operators can correlate the re-pin with
        # the user's image attachment instead of suspecting a cooldown.
        if requires_image_input:
            try:
                pinned_global = next(
                    c
                    for c in config.GLOBAL_LLM_CONFIGS
                    if int(c.get("id", 0)) == int(pinned_id)
                )
            except StopIteration:
                pinned_global = None
            if pinned_global is not None and not _cfg_supports_image_input(
                pinned_global
            ):
                logger.info(
                    "auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
                    "previous_config_id=%s",
                    thread_id,
                    search_space_id,
                    pinned_id,
                )
        logger.info(
            "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
            thread_id,
            search_space_id,
            pinned_id,
        )

    premium_eligible = (
        False if force_repin_free else await _is_premium_eligible(session, user_id)
    )
    if premium_eligible:
        premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
        preferred_premium = [
            c for c in premium_candidates if _is_preferred_premium_auto_config(c)
        ]
        eligible = preferred_premium or premium_candidates
    else:
        eligible = [c for c in candidates if _tier_of(c) != "premium"]

    if not eligible:
        if requires_image_input:
            raise ValueError(
                "Auto mode could not find a vision-capable LLM config for this user and quota state"
            )
        raise ValueError(
            "Auto mode could not find an eligible LLM config for this user and quota state"
        )

    selected_cfg, top_k_size = _select_pin(eligible, thread_id)
    selected_id = int(selected_cfg["id"])
    selected_tier = _tier_of(selected_cfg)

    thread.pinned_llm_config_id = selected_id
    await session.commit()

    if force_repin_free:
        logger.info(
            "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
            thread_id,
            search_space_id,
            pinned_id,
            selected_id,
        )

    if pinned_id is None:
        logger.info(
            "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
            thread_id,
            search_space_id,
            selected_id,
            selected_tier,
            premium_eligible,
        )
    else:
        logger.info(
            "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
            thread_id,
            search_space_id,
            pinned_id,
            selected_id,
            selected_tier,
            premium_eligible,
        )

    logger.info(
        "auto_pin_resolved thread_id=%s config_id=%s tier=%s "
        "auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
        thread_id,
        selected_id,
        selected_tier,
        selected_cfg.get("auto_pin_tier", "?"),
        int(selected_cfg.get("quality_score") or 0),
        top_k_size,
    )

    return AutoPinResolution(
        resolved_llm_config_id=selected_id,
        resolved_tier=selected_tier,
        from_existing_pin=False,
    )
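
# Illustrative call-site sketch (hypothetical names; the real caller lives in
# the chat streaming task): resolve Auto once per turn and fall back to a
# friendly SSE error when no config qualifies.
#
#     try:
#         resolution = await resolve_or_get_pinned_llm_config_id(
#             session,
#             thread_id=thread.id,
#             search_space_id=space.id,
#             user_id=user.id,
#             selected_llm_config_id=0,             # Auto (Fastest)
#             requires_image_input=turn_has_image,  # hypothetical flag
#         )
#     except ValueError:
#         emit_model_error_sse()                    # hypothetical handler
#     else:
#         llm_config_id = resolution.resolved_llm_config_id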

surfsense_backend/app/services/billable_calls.py (new file, 566 lines)

@@ -0,0 +1,566 @@
"""
|
||||
Per-call billable wrapper for image generation, vision LLM extraction, and
|
||||
any other short-lived premium operation that must charge against the user's
|
||||
shared premium credit pool.
|
||||
|
||||
The ``billable_call`` async context manager encapsulates the standard
|
||||
"reserve → execute → finalize / release → record audit row" lifecycle in a
|
||||
single primitive so callers (the image-generation REST route and the
|
||||
vision-LLM wrapper used during indexing) don't have to re-implement it.
|
||||
|
||||
KEY DESIGN POINTS (issue A, B):
|
||||
|
||||
1. **Session isolation.** ``billable_call`` takes no caller transaction.
|
||||
All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
|
||||
inside their own session context. Route callers use
|
||||
``shielded_async_session()`` by default; Celery callers can provide a
|
||||
worker-loop-safe session factory. This guarantees that quota
|
||||
commit/rollback can never accidentally flush or roll back rows the caller
|
||||
has staged in its main session (e.g. a freshly-created
|
||||
``ImageGeneration`` row).
|
||||
|
||||
2. **ContextVar safety.** The accumulator is scoped via
|
||||
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
|
||||
nested ``billable_call`` inside an outer chat turn cannot corrupt the
|
||||
chat turn's accumulator.
|
||||
|
||||
3. **Free configs are still audited.** Free calls bypass the reserve /
|
||||
finalize dance entirely but still record a ``TokenUsage`` audit row with
|
||||
the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution
|
||||
pipeline complete for analytics even when nothing is debited.
|
||||
|
||||
4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is
|
||||
responsible for translating that into HTTP 402. We *do not* catch the
|
||||
denial inside ``billable_call`` — letting it propagate also prevents
|
||||
the image-generation route from creating an ``ImageGeneration`` row
|
||||
for a request that never actually ran.
|
||||
"""

from __future__ import annotations

import asyncio
import logging
from collections.abc import AsyncIterator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
from typing import Any
from uuid import UUID, uuid4

from sqlalchemy.ext.asyncio import AsyncSession

from app.config import config
from app.db import shielded_async_session
from app.services.token_quota_service import (
    TokenQuotaService,
    estimate_call_reserve_micros,
)
from app.services.token_tracking_service import (
    TurnTokenAccumulator,
    record_token_usage,
    scoped_turn,
)

logger = logging.getLogger(__name__)

AUDIT_TIMEOUT_SECONDS = 10.0
BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
    {"video_presentation_generation", "podcast_generation"}
)
BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]


class QuotaInsufficientError(Exception):
    """Raised when ``TokenQuotaService.premium_reserve`` denies a billable
    call because the user has exhausted their premium credit pool.

    The route handler should catch this and return HTTP 402 Payment
    Required (or the equivalent for the surface area). Outside of the HTTP
    layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing)
    callers may catch this and degrade gracefully — e.g. fall back to OCR
    when vision is unavailable.
    """

    def __init__(
        self,
        *,
        usage_type: str,
        used_micros: int,
        limit_micros: int,
        remaining_micros: int,
    ) -> None:
        self.usage_type = usage_type
        self.used_micros = used_micros
        self.limit_micros = limit_micros
        self.remaining_micros = remaining_micros
        super().__init__(
            f"Premium credit exhausted for {usage_type}: "
            f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)"
        )
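
    # Illustrative route-side handling (hypothetical FastAPI handler; the
    # real one is the image-generation endpoint referenced in the module
    # docstring):
    #
    #     try:
    #         async with billable_call(..., billing_tier="premium") as acc:
    #             result = await generate_image(...)   # hypothetical call
    #     except QuotaInsufficientError as exc:
    #         raise HTTPException(status_code=402, detail=str(exc)) from exc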


class BillingSettlementError(Exception):
    """Raised when a premium call completed but credit settlement failed."""

    def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
        self.usage_type = usage_type
        self.user_id = user_id
        super().__init__(
            f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
        )


async def _rollback_safely(session: AsyncSession) -> None:
    rollback = getattr(session, "rollback", None)
    if rollback is not None:
        with suppress(Exception):
            await rollback()


async def _record_audit_best_effort(
    *,
    session_factory: BillableSessionFactory,
    usage_type: str,
    search_space_id: int,
    user_id: UUID,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    cost_micros: int,
    model_breakdown: dict[str, Any],
    call_details: dict[str, Any] | None,
    thread_id: int | None,
    message_id: int | None,
    audit_label: str,
    timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> None:
    """Persist a TokenUsage row without letting audit failure block callers.

    Premium settlement is mandatory, but TokenUsage is an audit trail. If the
    audit insert or commit hangs, user-facing artifacts such as videos and
    podcasts must still be able to transition to READY after settlement.
    """
    audit_thread_id = (
        None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id
    )

    async def _persist() -> None:
        logger.info(
            "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
            "total_tokens=%d cost_micros=%d",
            audit_label,
            usage_type,
            user_id,
            audit_thread_id,
            total_tokens,
            cost_micros,
        )
        async with session_factory() as audit_session:
            try:
                await record_token_usage(
                    audit_session,
                    usage_type=usage_type,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    cost_micros=cost_micros,
                    model_breakdown=model_breakdown,
                    call_details=call_details,
                    thread_id=audit_thread_id,
                    message_id=message_id,
                )
                logger.info(
                    "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
                    audit_label,
                    usage_type,
                    user_id,
                    audit_thread_id,
                )
                await audit_session.commit()
                logger.info(
                    "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
                    audit_label,
                    usage_type,
                    user_id,
                    audit_thread_id,
                )
            except BaseException:
                await _rollback_safely(audit_session)
                raise

    try:
        await asyncio.wait_for(_persist(), timeout=timeout_seconds)
    except TimeoutError:
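        # Note: on Python 3.11+ asyncio.wait_for raises the builtin
        # TimeoutError (asyncio.TimeoutError is an alias of it), so catching
        # TimeoutError here covers the asyncio timeout; older runtimes would
        # need asyncio.TimeoutError instead.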
        logger.warning(
            "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
            "timeout=%.1fs total_tokens=%d cost_micros=%d",
            audit_label,
            usage_type,
            user_id,
            audit_thread_id,
            timeout_seconds,
            total_tokens,
            cost_micros,
        )
    except Exception:
        logger.exception(
            "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
            "total_tokens=%d cost_micros=%d",
            audit_label,
            usage_type,
            user_id,
            audit_thread_id,
            total_tokens,
            cost_micros,
        )


@asynccontextmanager
async def billable_call(
    *,
    user_id: UUID,
    search_space_id: int,
    billing_tier: str,
    base_model: str,
    quota_reserve_tokens: int | None = None,
    quota_reserve_micros_override: int | None = None,
    usage_type: str,
    thread_id: int | None = None,
    message_id: int | None = None,
    call_details: dict[str, Any] | None = None,
    billable_session_factory: BillableSessionFactory | None = None,
    audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> AsyncIterator[TurnTokenAccumulator]:
    """Wrap a single billable LLM/image call.

    Args:
        user_id: Owner of the credit pool to debit. For vision-LLM during
            indexing this is the *search-space owner* (issue M), not the
            triggering user.
        search_space_id: Required — recorded on the ``TokenUsage`` audit row.
        billing_tier: ``"premium"`` debits; anything else (``"free"``) skips
            the reserve/finalize dance but still records an audit row with
            the captured cost.
        base_model: Used by :func:`estimate_call_reserve_micros` to compute
            a worst-case reservation from LiteLLM's pricing table.
        quota_reserve_tokens: Optional per-config override for the chat-style
            reserve estimator (vision LLM uses this).
        quota_reserve_micros_override: Optional flat micro-USD reservation
            (image generation uses this — its cost shape is per-image, not
            per-token).
        usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc.
            Recorded on the ``TokenUsage`` row.
        thread_id, message_id: Optional FK columns on ``TokenUsage``.
        call_details: Optional per-call metadata (model name, parameters)
            forwarded to ``record_token_usage``.
        billable_session_factory: Optional async context factory used for
            reserve/finalize/release/audit sessions. Defaults to
            ``shielded_async_session`` for route callers; Celery callers pass
            a worker-loop-safe session factory.
        audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
            Audit failure is best-effort and does not undo successful
            settlement.

    Yields:
        The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
        the underlying LLM/image API while inside the ``async with``; the
        ``TokenTrackingCallback`` populates the accumulator automatically.

    Raises:
        QuotaInsufficientError: when premium and ``premium_reserve`` denies.
    """
    is_premium = billing_tier == "premium"
    session_factory = billable_session_factory or shielded_async_session

    async with scoped_turn() as acc:
        # ---------- Free path: just audit -------------------------------
        if not is_premium:
            try:
                yield acc
            finally:
                # Always audit, even on exception, so we capture cost when
                # provider returns successfully but the caller raises later.
                await _record_audit_best_effort(
                    session_factory=session_factory,
                    usage_type=usage_type,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    prompt_tokens=acc.total_prompt_tokens,
                    completion_tokens=acc.total_completion_tokens,
                    total_tokens=acc.grand_total,
                    cost_micros=acc.total_cost_micros,
                    model_breakdown=acc.per_message_summary(),
                    call_details=call_details,
                    thread_id=thread_id,
                    message_id=message_id,
                    audit_label="free",
                    timeout_seconds=audit_timeout_seconds,
                )
            return

        # ---------- Premium path: reserve → execute → finalize ----------
        if quota_reserve_micros_override is not None:
            reserve_micros = max(1, int(quota_reserve_micros_override))
        else:
            reserve_micros = estimate_call_reserve_micros(
                base_model=base_model or "",
                quota_reserve_tokens=quota_reserve_tokens,
            )

        request_id = str(uuid4())

        async with session_factory() as quota_session:
            reserve_result = await TokenQuotaService.premium_reserve(
                db_session=quota_session,
                user_id=user_id,
                request_id=request_id,
                reserve_micros=reserve_micros,
            )

        if not reserve_result.allowed:
            logger.info(
                "[billable_call] reserve DENIED user=%s usage_type=%s "
                "reserve=%d used=%d limit=%d remaining=%d",
                user_id,
                usage_type,
                reserve_micros,
                reserve_result.used,
                reserve_result.limit,
                reserve_result.remaining,
            )
            raise QuotaInsufficientError(
                usage_type=usage_type,
                used_micros=reserve_result.used,
                limit_micros=reserve_result.limit,
                remaining_micros=reserve_result.remaining,
            )

        logger.info(
            "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d "
            "(remaining=%d)",
            user_id,
            usage_type,
            reserve_micros,
            reserve_result.remaining,
        )

        try:
            yield acc
        except BaseException:
            # Release on any failure (including QuotaInsufficientError raised
            # from a downstream call, asyncio cancellation, etc.). We use
            # BaseException so cancellation also releases.
            try:
                async with session_factory() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] premium_release failed for user=%s "
                    "reserve_micros=%d (reservation will be GC'd by quota "
                    "reconciliation if/when implemented)",
                    user_id,
                    reserve_micros,
                )
            raise

        # ---------- Success: finalize + audit ----------------------------
        actual_micros = acc.total_cost_micros
        try:
            logger.info(
                "[billable_call] finalize start user=%s usage_type=%s actual=%d "
                "reserved=%d thread=%s",
                user_id,
                usage_type,
                actual_micros,
                reserve_micros,
                thread_id,
            )
            async with session_factory() as quota_session:
                final_result = await TokenQuotaService.premium_finalize(
                    db_session=quota_session,
                    user_id=user_id,
                    request_id=request_id,
                    actual_micros=actual_micros,
                    reserved_micros=reserve_micros,
                )
            logger.info(
                "[billable_call] finalize user=%s usage_type=%s actual=%d "
                "reserved=%d → used=%d/%d (remaining=%d)",
                user_id,
                usage_type,
                actual_micros,
                reserve_micros,
                final_result.used,
                final_result.limit,
                final_result.remaining,
            )
        except Exception as finalize_exc:
            # Last-ditch: if finalize itself fails, we must at least release
            # so the reservation doesn't leak.
            logger.exception(
                "[billable_call] premium_finalize failed for user=%s; "
                "attempting release",
                user_id,
            )
            try:
                async with session_factory() as quota_session:
                    await TokenQuotaService.premium_release(
                        db_session=quota_session,
                        user_id=user_id,
                        reserved_micros=reserve_micros,
                    )
            except Exception:
                logger.exception(
                    "[billable_call] release after finalize failure ALSO failed "
                    "for user=%s",
                    user_id,
                )
            raise BillingSettlementError(
                usage_type=usage_type,
                user_id=user_id,
                cause=finalize_exc,
            ) from finalize_exc

        await _record_audit_best_effort(
            session_factory=session_factory,
            usage_type=usage_type,
            search_space_id=search_space_id,
            user_id=user_id,
            prompt_tokens=acc.total_prompt_tokens,
            completion_tokens=acc.total_completion_tokens,
            total_tokens=acc.grand_total,
            cost_micros=actual_micros,
            model_breakdown=acc.per_message_summary(),
            call_details=call_details,
            thread_id=thread_id,
            message_id=message_id,
            audit_label="premium",
            timeout_seconds=audit_timeout_seconds,
        )
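
# Illustrative usage sketch (hypothetical caller code, not part of this module):
#
#     async with billable_call(
#         user_id=owner_id,
#         search_space_id=space_id,
#         billing_tier="premium",
#         base_model="gpt-4o-mini",                      # assumed model name
#         quota_reserve_micros_override=DEFAULT_IMAGE_RESERVE_MICROS,
#         usage_type="image_generation",
#     ) as acc:
#         image = await provider.generate(...)           # hypothetical call
#     # On exit the reservation is finalized against acc.total_cost_micros and
#     # a TokenUsage audit row is written (best effort).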


async def _resolve_agent_billing_for_search_space(
    session: AsyncSession,
    search_space_id: int,
    *,
    thread_id: int | None = None,
) -> tuple[UUID, str, str]:
    """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
    agent LLM.

    Used by Celery tasks (podcast generation, video presentation) to bill the
    search-space owner's premium credit pool when the agent LLM is premium.

    Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:

    - Search space not found / no ``agent_llm_id``: raise ``ValueError``.
    - **Auto mode** (``id == AUTO_FASTEST_ID == 0``):
      * ``thread_id`` is set: delegate to
        ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and
        recurse into the resolved id. Reuses chat's existing pin if present
        so the same model bills for chat + downstream podcast/video. If the
        user is not premium-eligible, the pin service auto-restricts to free
        deployments — denial only happens later in
        ``billable_call.premium_reserve`` if the pin really is premium and
        credit ran out mid-flow.
      * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat
        for any future direct-API path; today both Celery tasks always pass
        ``thread_id``.
    - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]``
      (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault),
      ``base_model = litellm_params.get("base_model") or model_name`` —
      NOT provider-prefixed, matching chat's cost-map lookup convention.
    - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches
      ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``);
      ``base_model`` from ``litellm_params`` or ``model_name``.

    Note on imports: ``llm_service``, ``auto_model_pin_service``, and
    ``llm_router_service`` are imported lazily inside the function body to
    avoid hoisting litellm side-effects (``litellm.callbacks =
    [token_tracker]``, ``litellm.drop_params``, etc.) into
    ``billable_calls.py``'s module load path.
    """
    from sqlalchemy import select

    from app.db import NewLLMConfig, SearchSpace

    result = await session.execute(
        select(SearchSpace).where(SearchSpace.id == search_space_id)
    )
    search_space = result.scalars().first()
    if search_space is None:
        raise ValueError(f"Search space {search_space_id} not found")

    agent_llm_id = search_space.agent_llm_id
    if agent_llm_id is None:
        raise ValueError(
            f"Search space {search_space_id} has no agent_llm_id configured"
        )

    owner_user_id: UUID = search_space.user_id

    from app.services.auto_model_pin_service import (
        AUTO_FASTEST_ID,
        resolve_or_get_pinned_llm_config_id,
    )

    if agent_llm_id == AUTO_FASTEST_ID:
        if thread_id is None:
            return owner_user_id, "free", "auto"
        try:
            resolution = await resolve_or_get_pinned_llm_config_id(
                session,
                thread_id=thread_id,
                search_space_id=search_space_id,
                user_id=str(owner_user_id),
                selected_llm_config_id=AUTO_FASTEST_ID,
            )
        except ValueError:
            logger.warning(
                "[agent_billing] Auto-mode pin resolution failed for "
                "search_space=%s thread=%s; falling back to free",
                search_space_id,
                thread_id,
                exc_info=True,
            )
            return owner_user_id, "free", "auto"
        agent_llm_id = resolution.resolved_llm_config_id

    if agent_llm_id < 0:
        from app.services.llm_service import get_global_llm_config

        cfg = get_global_llm_config(agent_llm_id) or {}
        billing_tier = str(cfg.get("billing_tier", "free")).lower()
        litellm_params = cfg.get("litellm_params") or {}
        base_model = litellm_params.get("base_model") or cfg.get("model_name") or ""
        return owner_user_id, billing_tier, base_model

    nlc_result = await session.execute(
        select(NewLLMConfig).where(
            NewLLMConfig.id == agent_llm_id,
            NewLLMConfig.search_space_id == search_space_id,
        )
    )
    nlc = nlc_result.scalars().first()
    base_model = ""
    if nlc is not None:
        litellm_params = nlc.litellm_params or {}
        base_model = litellm_params.get("base_model") or nlc.model_name or ""
    return owner_user_id, "free", base_model


__all__ = [
    "BillingSettlementError",
    "QuotaInsufficientError",
    "_resolve_agent_billing_for_search_space",
    "billable_call",
]


# Re-export the config knob so callers don't have to import config just for
# the default image reserve.
DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS

@@ -408,12 +408,37 @@ class ComposioService:
        files = []
        next_token = None
        if isinstance(data, dict):
            inner_data = data.get("data", data)
            response_data = (
                inner_data.get("response_data", {})
                if isinstance(inner_data, dict)
                else {}
            )
            # Try direct access first, then nested
            files = data.get("files", []) or data.get("data", {}).get("files", [])
            files = (
                data.get("files", [])
                or (
                    inner_data.get("files", [])
                    if isinstance(inner_data, dict)
                    else []
                )
                or response_data.get("files", [])
            )
            next_token = (
                data.get("nextPageToken")
                or data.get("next_page_token")
                or data.get("data", {}).get("nextPageToken")
                or (
                    inner_data.get("nextPageToken")
                    if isinstance(inner_data, dict)
                    else None
                )
                or (
                    inner_data.get("next_page_token")
                    if isinstance(inner_data, dict)
                    else None
                )
                or response_data.get("nextPageToken")
                or response_data.get("next_page_token")
            )
        elif isinstance(data, list):
            files = data
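        # The "direct key, then data[...], then response_data[...]" fallbacks
        # above recur for files, messages, and events. A hedged refactor
        # sketch (hypothetical helper, not present in this diff) expressing
        # the same lookup order:
        #
        #     def _first_present(payload, inner, response, key, default):
        #         for container in (payload, inner, response):
        #             if isinstance(container, dict) and container.get(key):
        #                 return container[key]
        #         return default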

@@ -819,24 +844,61 @@ class ComposioService:
        next_token = None
        result_size_estimate = None
        if isinstance(data, dict):
            inner_data = data.get("data", data)
            response_data = (
                inner_data.get("response_data", {})
                if isinstance(inner_data, dict)
                else {}
            )
            messages = (
                data.get("messages", [])
                or data.get("data", {}).get("messages", [])
                or (
                    inner_data.get("messages", [])
                    if isinstance(inner_data, dict)
                    else []
                )
                or response_data.get("messages", [])
                or data.get("emails", [])
                or (
                    inner_data.get("emails", [])
                    if isinstance(inner_data, dict)
                    else []
                )
                or response_data.get("emails", [])
            )
            # Check for pagination token in various possible locations
            next_token = (
                data.get("nextPageToken")
                or data.get("next_page_token")
                or data.get("data", {}).get("nextPageToken")
                or data.get("data", {}).get("next_page_token")
                or (
                    inner_data.get("nextPageToken")
                    if isinstance(inner_data, dict)
                    else None
                )
                or (
                    inner_data.get("next_page_token")
                    if isinstance(inner_data, dict)
                    else None
                )
                or response_data.get("nextPageToken")
                or response_data.get("next_page_token")
            )
            # Extract resultSizeEstimate if available (Gmail API provides this)
            result_size_estimate = (
                data.get("resultSizeEstimate")
                or data.get("result_size_estimate")
                or data.get("data", {}).get("resultSizeEstimate")
                or data.get("data", {}).get("result_size_estimate")
                or (
                    inner_data.get("resultSizeEstimate")
                    if isinstance(inner_data, dict)
                    else None
                )
                or (
                    inner_data.get("result_size_estimate")
                    if isinstance(inner_data, dict)
                    else None
                )
                or response_data.get("resultSizeEstimate")
                or response_data.get("result_size_estimate")
            )
        elif isinstance(data, list):
            messages = data

@@ -864,7 +926,7 @@ class ComposioService:
        try:
            result = await self.execute_tool(
                connected_account_id=connected_account_id,
                tool_name="GMAIL_GET_MESSAGE_BY_MESSAGE_ID",
                tool_name="GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID",
                params={"message_id": message_id},  # snake_case
                entity_id=entity_id,
            )

@@ -872,7 +934,13 @@ class ComposioService:
            if not result.get("success"):
                return None, result.get("error", "Unknown error")

            return result.get("data"), None
            data = result.get("data")
            if isinstance(data, dict):
                inner_data = data.get("data", data)
                if isinstance(inner_data, dict):
                    return inner_data.get("response_data", inner_data), None

            return data, None

        except Exception as e:
            logger.error(f"Failed to get Gmail message detail: {e!s}")

@@ -928,10 +996,27 @@ class ComposioService:
        # Try different possible response structures
        events = []
        if isinstance(data, dict):
            inner_data = data.get("data", data)
            response_data = (
                inner_data.get("response_data", {})
                if isinstance(inner_data, dict)
                else {}
            )
            events = (
                data.get("items", [])
                or data.get("data", {}).get("items", [])
                or (
                    inner_data.get("items", [])
                    if isinstance(inner_data, dict)
                    else []
                )
                or response_data.get("items", [])
                or data.get("events", [])
                or (
                    inner_data.get("events", [])
                    if isinstance(inner_data, dict)
                    else []
                )
                or response_data.get("events", [])
            )
        elif isinstance(data, list):
            events = data

@@ -1,6 +1,8 @@
import asyncio
import os
import time
from datetime import datetime
from threading import Lock
from typing import Any

import httpx

@@ -2769,12 +2771,22 @@ class ConnectorService:
        """
        Get all available (enabled) connector types for a search space.

        Phase 1.4: results are cached per ``search_space_id`` for
        :data:`_DISCOVERY_TTL_SECONDS`. Cache key is independent of session
        identity — the cached value is plain data, safe to share across
        requests. Invalidate on connector add/update/delete via
        :func:`invalidate_connector_discovery_cache`.

        Args:
            search_space_id: The search space ID

        Returns:
            List of SearchSourceConnectorType enums for enabled connectors
        """
        cached = _get_cached_connectors(search_space_id)
        if cached is not None:
            return list(cached)

        query = (
            select(SearchSourceConnector.connector_type)
            .filter(

@@ -2784,8 +2796,9 @@ class ConnectorService:
        )

        result = await self.session.execute(query)
        connector_types = result.scalars().all()
        return list(connector_types)
        connector_types = list(result.scalars().all())
        _set_cached_connectors(search_space_id, connector_types)
        return connector_types

    async def get_available_document_types(
        self,

@@ -2794,12 +2807,22 @@ class ConnectorService:
        """
        Get all document types that have at least one document in the search space.

        Phase 1.4: cached per ``search_space_id`` for
        :data:`_DISCOVERY_TTL_SECONDS`. Invalidate via
        :func:`invalidate_connector_discovery_cache` when a connector
        finishes indexing new documents (or document types are otherwise
        added/removed).

        Args:
            search_space_id: The search space ID

        Returns:
            List of document type strings that have documents indexed
        """
        cached = _get_cached_doc_types(search_space_id)
        if cached is not None:
            return list(cached)

        from sqlalchemy import distinct

        from app.db import Document

@@ -2809,5 +2832,164 @@ class ConnectorService:
        )

        result = await self.session.execute(query)
        doc_types = result.scalars().all()
        return [str(dt) for dt in doc_types]
        doc_types = [str(dt) for dt in result.scalars().all()]
        _set_cached_doc_types(search_space_id, doc_types)
        return doc_types


# ---------------------------------------------------------------------------
# Connector / document-type discovery TTL cache (Phase 1.4)
# ---------------------------------------------------------------------------
#
# Both ``get_available_connectors`` and ``get_available_document_types`` are
# called on EVERY chat turn from ``create_surfsense_deep_agent``. Each query
# hits Postgres and contributes to per-turn agent build latency. Their
# results change infrequently — only when the user adds/edits/removes a
# connector, or when an indexer commits a new document type. A short TTL
# cache (default 30s, env-tunable) collapses N concurrent calls into one
# DB roundtrip with bounded staleness.
#
# Invalidation: connector mutation routes (create / update / delete) call
# ``invalidate_connector_discovery_cache(search_space_id)`` to clear the
# entry for the affected space. Multi-replica deployments still pay one
# DB roundtrip per replica per TTL window, which is fine — staleness is
# bounded and the alternative (cross-replica fanout) is not worth the
# coupling here.

_DISCOVERY_TTL_SECONDS: float = float(
    os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
)

# Per-search-space caches. Keyed by ``search_space_id``; value is
# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock —
# read-mostly workload, sub-microsecond contention.
_connectors_cache: dict[int, tuple[float, list[SearchSourceConnectorType]]] = {}
_doc_types_cache: dict[int, tuple[float, list[str]]] = {}
_cache_lock = Lock()


def _get_cached_connectors(
    search_space_id: int,
) -> list[SearchSourceConnectorType] | None:
    if _DISCOVERY_TTL_SECONDS <= 0:
        return None
    with _cache_lock:
        entry = _connectors_cache.get(search_space_id)
        if entry is None:
            return None
        expires_at, payload = entry
        if time.monotonic() >= expires_at:
            _connectors_cache.pop(search_space_id, None)
            return None
        return payload


def _set_cached_connectors(
    search_space_id: int, payload: list[SearchSourceConnectorType]
) -> None:
    if _DISCOVERY_TTL_SECONDS <= 0:
        return
    expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
    with _cache_lock:
        _connectors_cache[search_space_id] = (expires_at, list(payload))


def _get_cached_doc_types(search_space_id: int) -> list[str] | None:
    if _DISCOVERY_TTL_SECONDS <= 0:
        return None
    with _cache_lock:
        entry = _doc_types_cache.get(search_space_id)
        if entry is None:
            return None
        expires_at, payload = entry
        if time.monotonic() >= expires_at:
            _doc_types_cache.pop(search_space_id, None)
            return None
        return payload


def _set_cached_doc_types(search_space_id: int, payload: list[str]) -> None:
    if _DISCOVERY_TTL_SECONDS <= 0:
        return
    expires_at = time.monotonic() + _DISCOVERY_TTL_SECONDS
    with _cache_lock:
        _doc_types_cache[search_space_id] = (expires_at, list(payload))


def invalidate_connector_discovery_cache(search_space_id: int | None = None) -> None:
    """Drop cached discovery results for ``search_space_id`` (or all spaces).

    Connector CRUD routes / indexer pipelines call this when they mutate
    the rows backing :func:`ConnectorService.get_available_connectors` /
    :func:`get_available_document_types`. ``None`` clears every space —
    useful in tests and on bulk imports.
    """
    with _cache_lock:
        if search_space_id is None:
            _connectors_cache.clear()
            _doc_types_cache.clear()
        else:
            _connectors_cache.pop(search_space_id, None)
            _doc_types_cache.pop(search_space_id, None)
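
# Illustrative route-side call (hypothetical handler body; the real connector
# CRUD routes perform the equivalent after their commit):
#
#     await session.commit()
#     invalidate_connector_discovery_cache(search_space_id)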


def _invalidate_connectors_only(search_space_id: int | None = None) -> None:
    with _cache_lock:
        if search_space_id is None:
            _connectors_cache.clear()
        else:
            _connectors_cache.pop(search_space_id, None)


def _invalidate_doc_types_only(search_space_id: int | None = None) -> None:
    with _cache_lock:
        if search_space_id is None:
            _doc_types_cache.clear()
        else:
            _doc_types_cache.pop(search_space_id, None)


def _register_invalidation_listeners() -> None:
    """Wire SQLAlchemy ORM events so cache stays consistent automatically.

    Listening on ``after_insert`` / ``after_update`` / ``after_delete``
    means every successful INSERT/UPDATE/DELETE that goes through the ORM
    invalidates the affected search space's cached discovery payload —
    no need to sprinkle ``invalidate_*`` calls across 30+ connector
    routes. Bulk operations that bypass the ORM (e.g.
    ``session.execute(insert(...))`` without a mapped object) still need
    explicit invalidation; document indexers already commit through the
    ORM so document-type discovery is covered.
    """
    from sqlalchemy import event

    # Imported here (not at module top) to avoid a circular import:
    # app.services.connector_service is itself imported from app.db's
    # ecosystem indirectly via several CRUD modules.
    from app.db import Document, SearchSourceConnector

    def _connector_changed(_mapper, _connection, target) -> None:
        sid = getattr(target, "search_space_id", None)
        if sid is not None:
            _invalidate_connectors_only(int(sid))

    def _document_changed(_mapper, _connection, target) -> None:
        sid = getattr(target, "search_space_id", None)
        if sid is not None:
            _invalidate_doc_types_only(int(sid))

    for evt in ("after_insert", "after_update", "after_delete"):
        event.listen(SearchSourceConnector, evt, _connector_changed)
        event.listen(Document, evt, _document_changed)


try:
    _register_invalidation_listeners()
except Exception:  # pragma: no cover - defensive; never block module import
    import logging as _logging

    _logging.getLogger(__name__).exception(
        "Failed to register connector discovery cache invalidation listeners; "
        "stale cache risk: explicit invalidate_connector_discovery_cache calls "
        "may be required."
    )
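
# Bulk writes that bypass the ORM do not fire the listeners above, so they
# still need explicit invalidation, as the docstring notes. Hedged sketch
# (hypothetical core-style insert, not code from this diff):
#
#     await session.execute(insert(Document).values(rows))   # no ORM events
#     await session.commit()
#     invalidate_connector_discovery_cache(search_space_id)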

@@ -17,7 +17,7 @@ from app.db import (
    SearchSourceConnector,
    SearchSourceConnectorType,
)
from app.utils.google_credentials import build_composio_credentials
from app.services.composio_service import ComposioService

logger = logging.getLogger(__name__)

@@ -78,14 +78,49 @@ class GmailToolMetadataService:
    def __init__(self, db_session: AsyncSession):
        self._db_session = db_session

    async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
        if (
    def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
        return (
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if cca_id:
                return build_composio_credentials(cca_id)
        )

    def _get_composio_connected_account_id(
        self, connector: SearchSourceConnector
    ) -> str:
        cca_id = connector.config.get("composio_connected_account_id")
        if not cca_id:
            raise ValueError("Composio connected_account_id not found")
        return cca_id

    def _unwrap_composio_data(self, data: Any) -> Any:
        if isinstance(data, dict):
            inner = data.get("data", data)
            if isinstance(inner, dict):
                return inner.get("response_data", inner)
            return inner
        return data

    async def _execute_composio_gmail_tool(
        self,
        connector: SearchSourceConnector,
        tool_name: str,
        params: dict[str, Any],
    ) -> tuple[Any, str | None]:
        result = await ComposioService().execute_tool(
            connected_account_id=self._get_composio_connected_account_id(connector),
            tool_name=tool_name,
            params=params,
            entity_id=f"surfsense_{connector.user_id}",
        )
        if not result.get("success"):
            return None, result.get("error", "Unknown Composio Gmail error")
        return self._unwrap_composio_data(result.get("data")), None

    async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
        if self._is_composio_connector(connector):
            raise ValueError(
                "Composio Gmail connectors must use Composio tool execution"
            )

        config_data = dict(connector.config)

@@ -139,6 +174,12 @@ class GmailToolMetadataService:
        if not connector:
            return True

        if self._is_composio_connector(connector):
            _profile, error = await self._execute_composio_gmail_tool(
                connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
            )
            return bool(error)

        creds = await self._build_credentials(connector)
        service = build("gmail", "v1", credentials=creds)
        await asyncio.get_event_loop().run_in_executor(

@@ -221,14 +262,21 @@ class GmailToolMetadataService:
            )
            connector = result.scalar_one_or_none()
            if connector:
                creds = await self._build_credentials(connector)
                service = build("gmail", "v1", credentials=creds)
                profile = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda service=service: (
                        service.users().getProfile(userId="me").execute()
                    ),
                )
                if self._is_composio_connector(connector):
                    profile, error = await self._execute_composio_gmail_tool(
                        connector, "GMAIL_GET_PROFILE", {"user_id": "me"}
                    )
                    if error:
                        raise RuntimeError(error)
                else:
                    creds = await self._build_credentials(connector)
                    service = build("gmail", "v1", credentials=creds)
                    profile = await asyncio.get_event_loop().run_in_executor(
                        None,
                        lambda service=service: (
                            service.users().getProfile(userId="me").execute()
                        ),
                    )
                acc_dict["email"] = profile.get("emailAddress", "")
        except Exception:
            logger.warning(

@@ -298,6 +346,23 @@ class GmailToolMetadataService:
        Returns ``None`` on any failure so callers can degrade gracefully.
        """
        try:
            if self._is_composio_connector(connector):
                if not draft_id:
                    draft_id = await self._find_composio_draft_id(connector, message_id)
                if not draft_id:
                    return None

                draft, error = await self._execute_composio_gmail_tool(
                    connector,
                    "GMAIL_GET_DRAFT",
                    {"user_id": "me", "draft_id": draft_id, "format": "full"},
                )
                if error or not isinstance(draft, dict):
                    return None

                payload = draft.get("message", {}).get("payload", {})
                return self._extract_body_from_payload(payload)

            creds = await self._build_credentials(connector)
            service = build("gmail", "v1", credentials=creds)

@@ -326,6 +391,33 @@ class GmailToolMetadataService:
            )
            return None

    async def _find_composio_draft_id(
        self, connector: SearchSourceConnector, message_id: str
    ) -> str | None:
        page_token = ""
        while True:
            params: dict[str, Any] = {
                "user_id": "me",
                "max_results": 100,
                "verbose": False,
            }
            if page_token:
                params["page_token"] = page_token

            data, error = await self._execute_composio_gmail_tool(
                connector, "GMAIL_LIST_DRAFTS", params
            )
            if error or not isinstance(data, dict):
                return None

            for draft in data.get("drafts", []):
                if draft.get("message", {}).get("id") == message_id:
                    return draft.get("id")

            page_token = data.get("nextPageToken") or data.get("next_page_token") or ""
            if not page_token:
                return None

    async def _find_draft_id(self, service: Any, message_id: str) -> str | None:
        """Resolve a draft ID from its message ID by scanning drafts.list."""
        try:

@@ -14,6 +14,7 @@ from app.db import (
    SearchSourceConnector,
    SearchSourceConnectorType,
)
from app.services.composio_service import ComposioService
from app.utils.document_converters import (
    create_document_chunks,
    embed_text,

@@ -21,7 +22,6 @@ from app.utils.document_converters import (
    generate_document_summary,
    generate_unique_identifier_hash,
)
from app.utils.google_credentials import build_composio_credentials

logger = logging.getLogger(__name__)
|
||||
|
|
@ -203,23 +203,46 @@ class GoogleCalendarKBSyncService:
|
|||
logger.warning("Document %s not found in KB", document_id)
|
||||
return {"status": "not_indexed"}
|
||||
|
||||
creds = await self._build_credentials_for_connector(connector_id)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
calendar_id = (document.document_metadata or {}).get(
|
||||
"calendar_id"
|
||||
) or "primary"
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
connector = await self._get_connector(connector_id)
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
composio_result = await ComposioService().execute_tool(
|
||||
connected_account_id=cca_id,
|
||||
tool_name="GOOGLECALENDAR_EVENTS_GET",
|
||||
params={"calendar_id": calendar_id, "event_id": event_id},
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
)
|
||||
if not composio_result.get("success"):
|
||||
raise RuntimeError(
|
||||
composio_result.get("error", "Unknown Composio Calendar error")
|
||||
)
|
||||
live_event = composio_result.get("data", {})
|
||||
if isinstance(live_event, dict):
|
||||
live_event = live_event.get("data", live_event)
|
||||
if isinstance(live_event, dict):
|
||||
live_event = live_event.get("response_data", live_event)
|
||||
else:
|
||||
creds = await self._build_credentials_for_connector(connector_id)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
|
||||
event_summary = live_event.get("summary", "")
|
||||
description = live_event.get("description", "")
|
||||
|
|
@ -322,7 +345,7 @@ class GoogleCalendarKBSyncService:
|
|||
await self.db_session.rollback()
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
||||
async def _get_connector(self, connector_id: int) -> SearchSourceConnector:
|
||||
result = await self.db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id
|
||||
|
|
@ -331,15 +354,17 @@ class GoogleCalendarKBSyncService:
|
|||
connector = result.scalar_one_or_none()
|
||||
if not connector:
|
||||
raise ValueError(f"Connector {connector_id} not found")
|
||||
return connector
|
||||
|
||||
async def _build_credentials_for_connector(self, connector_id: int) -> Credentials:
|
||||
connector = await self._get_connector(connector_id)
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
raise ValueError(
|
||||
"Composio Calendar connectors must use Composio tool execution"
|
||||
)
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -94,15 +94,49 @@ class GoogleCalendarToolMetadataService:
|
|||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if (
|
||||
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||
return (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
return build_composio_credentials(cca_id)
|
||||
)
|
||||
|
||||
def _get_composio_connected_account_id(
|
||||
self, connector: SearchSourceConnector
|
||||
) -> str:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
return cca_id
|
||||
|
||||
async def _execute_composio_calendar_tool(
|
||||
self,
|
||||
connector: SearchSourceConnector,
|
||||
tool_name: str,
|
||||
params: dict,
|
||||
) -> tuple[dict | list | None, str | None]:
|
||||
service = ComposioService()
|
||||
result = await service.execute_tool(
|
||||
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||
tool_name=tool_name,
|
||||
params=params,
|
||||
entity_id=f"surfsense_{connector.user_id}",
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, result.get("error", "Unknown Composio Calendar error")
|
||||
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
inner = data.get("data", data)
|
||||
if isinstance(inner, dict):
|
||||
return inner.get("response_data", inner), None
|
||||
return inner, None
|
||||
return data, None
|
||||
|
||||
async def _build_credentials(self, connector: SearchSourceConnector) -> Credentials:
|
||||
if self._is_composio_connector(connector):
|
||||
raise ValueError(
|
||||
"Composio Calendar connectors must use Composio tool execution"
|
||||
)
|
||||
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -156,6 +190,14 @@ class GoogleCalendarToolMetadataService:
|
|||
if not connector:
|
||||
return True
|
||||
|
||||
if self._is_composio_connector(connector):
|
||||
_data, error = await self._execute_composio_calendar_tool(
|
||||
connector,
|
||||
"GOOGLECALENDAR_GET_CALENDAR",
|
||||
{"calendar_id": "primary"},
|
||||
)
|
||||
return bool(error)
|
||||
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
|
|
@ -255,16 +297,48 @@ class GoogleCalendarToolMetadataService:
|
|||
timezone_str = ""
|
||||
if connector:
|
||||
try:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
if self._is_composio_connector(connector):
|
||||
cal_list, cal_error = await self._execute_composio_calendar_tool(
|
||||
connector, "GOOGLECALENDAR_LIST_CALENDARS", {}
|
||||
)
|
||||
if cal_error:
|
||||
raise RuntimeError(cal_error)
|
||||
(
|
||||
settings,
|
||||
settings_error,
|
||||
) = await self._execute_composio_calendar_tool(
|
||||
connector,
|
||||
"GOOGLECALENDAR_SETTINGS_GET",
|
||||
{"setting": "timezone"},
|
||||
)
|
||||
if not settings_error and isinstance(settings, dict):
|
||||
timezone_str = settings.get("value", "")
|
||||
else:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
cal_list = await loop.run_in_executor(
|
||||
None, lambda: service.calendarList().list().execute()
|
||||
)
|
||||
for cal in cal_list.get("items", []):
|
||||
cal_list = await loop.run_in_executor(
|
||||
None, lambda: service.calendarList().list().execute()
|
||||
)
|
||||
|
||||
tz_setting = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: service.settings().get(setting="timezone").execute(),
|
||||
)
|
||||
timezone_str = tz_setting.get("value", "")
|
||||
|
||||
calendar_items = []
|
||||
if isinstance(cal_list, dict):
|
||||
calendar_items = (
|
||||
cal_list.get("items") or cal_list.get("calendars") or []
|
||||
)
|
||||
elif isinstance(cal_list, list):
|
||||
calendar_items = cal_list
|
||||
|
||||
for cal in calendar_items:
|
||||
calendars.append(
|
||||
{
|
||||
"id": cal.get("id", ""),
|
||||
|
|
@ -272,12 +346,6 @@ class GoogleCalendarToolMetadataService:
|
|||
"primary": cal.get("primary", False),
|
||||
}
|
||||
)
|
||||
|
||||
tz_setting = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: service.settings().get(setting="timezone").execute(),
|
||||
)
|
||||
timezone_str = tz_setting.get("value", "")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch calendars/timezone for connector %s",
|
||||
|
|
@ -321,20 +389,29 @@ class GoogleCalendarToolMetadataService:
|
|||
|
||||
event_dict = event.to_dict()
|
||||
try:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
calendar_id = event.calendar_id or "primary"
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event.event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
if self._is_composio_connector(connector):
|
||||
live_event, error = await self._execute_composio_calendar_tool(
|
||||
connector,
|
||||
"GOOGLECALENDAR_EVENTS_GET",
|
||||
{"calendar_id": calendar_id, "event_id": event.event_id},
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
else:
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
live_event = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.get(calendarId=calendar_id, eventId=event.event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
|
||||
event_dict["summary"] = live_event.get("summary", event_dict["summary"])
|
||||
event_dict["description"] = live_event.get(
|
||||
|
|
@ -376,12 +453,30 @@ class GoogleCalendarToolMetadataService:
|
|||
) -> dict:
|
||||
resolved = await self._resolve_event(search_space_id, user_id, event_ref)
|
||||
if not resolved:
|
||||
live_resolved = await self._resolve_live_event(
|
||||
search_space_id, user_id, event_ref
|
||||
)
|
||||
if not live_resolved:
|
||||
return {
|
||||
"error": (
|
||||
f"Event '{event_ref}' not found in your indexed or live Google Calendar events. "
|
||||
"This could mean: (1) the event doesn't exist, "
|
||||
"(2) the event name is different, or "
|
||||
"(3) the connected calendar account cannot access it."
|
||||
)
|
||||
}
|
||||
|
||||
connector, live_event = live_resolved
|
||||
account = GoogleCalendarAccount.from_connector(connector)
|
||||
acc_dict = account.to_dict()
|
||||
auth_expired = await self._check_account_health(connector.id)
|
||||
acc_dict["auth_expired"] = auth_expired
|
||||
if auth_expired:
|
||||
await self._persist_auth_expired(connector.id)
|
||||
|
||||
return {
|
||||
"error": (
|
||||
f"Event '{event_ref}' not found in your indexed Google Calendar events. "
|
||||
"This could mean: (1) the event doesn't exist, (2) it hasn't been indexed yet, "
|
||||
"or (3) the event name is different."
|
||||
)
|
||||
"account": acc_dict,
|
||||
"event": self._event_dict_from_live_event(live_event),
|
||||
}
|
||||
|
||||
document, connector = resolved
|
||||
|
|
@ -429,3 +524,110 @@ class GoogleCalendarToolMetadataService:
|
|||
if row:
|
||||
return row[0], row[1]
|
||||
return None
|
||||
|
||||
async def _resolve_live_event(
|
||||
self, search_space_id: int, user_id: str, event_ref: str
|
||||
) -> tuple[SearchSourceConnector, dict] | None:
|
||||
result = await self._db_session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
and_(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(CALENDAR_CONNECTOR_TYPES),
|
||||
)
|
||||
)
|
||||
.order_by(SearchSourceConnector.last_indexed_at.desc())
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
for connector in connectors:
|
||||
try:
|
||||
events = await self._search_live_events(connector, event_ref)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to search live calendar events for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
if not events:
|
||||
continue
|
||||
|
||||
normalized_ref = event_ref.strip().lower()
|
||||
exact_match = next(
|
||||
(
|
||||
event
|
||||
for event in events
|
||||
if event.get("summary", "").strip().lower() == normalized_ref
|
||||
),
|
||||
None,
|
||||
)
|
||||
return connector, exact_match or events[0]
|
||||
|
||||
return None
|
||||
|
||||
async def _search_live_events(
|
||||
self, connector: SearchSourceConnector, event_ref: str
|
||||
) -> list[dict]:
|
||||
if self._is_composio_connector(connector):
|
||||
data, error = await self._execute_composio_calendar_tool(
|
||||
connector,
|
||||
"GOOGLECALENDAR_EVENTS_LIST",
|
||||
{
|
||||
"calendar_id": "primary",
|
||||
"q": event_ref,
|
||||
"max_results": 10,
|
||||
"single_events": True,
|
||||
"order_by": "startTime",
|
||||
},
|
||||
)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
if isinstance(data, dict):
|
||||
return data.get("items") or data.get("events") or []
|
||||
return data if isinstance(data, list) else []
|
||||
|
||||
creds = await self._build_credentials(connector)
|
||||
loop = asyncio.get_event_loop()
|
||||
service = await loop.run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.list(
|
||||
calendarId="primary",
|
||||
q=event_ref,
|
||||
maxResults=10,
|
||||
singleEvents=True,
|
||||
orderBy="startTime",
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
return response.get("items", [])
|
||||
|
||||
def _event_dict_from_live_event(self, event: dict) -> dict:
|
||||
start_data = event.get("start", {})
|
||||
end_data = event.get("end", {})
|
||||
return {
|
||||
"event_id": event.get("id", ""),
|
||||
"summary": event.get("summary", "No Title"),
|
||||
"start": start_data.get("dateTime", start_data.get("date", "")),
|
||||
"end": end_data.get("dateTime", end_data.get("date", "")),
|
||||
"description": event.get("description", ""),
|
||||
"location": event.get("location", ""),
|
||||
"attendees": [
|
||||
{
|
||||
"email": attendee.get("email", ""),
|
||||
"responseStatus": attendee.get("responseStatus", ""),
|
||||
}
|
||||
for attendee in event.get("attendees", [])
|
||||
],
|
||||
"calendar_id": event.get("calendarId", "primary"),
|
||||
"document_id": None,
|
||||
"indexed_at": None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from app.db import (
|
|||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -67,6 +67,42 @@ class GoogleDriveToolMetadataService:
|
|||
def __init__(self, db_session: AsyncSession):
|
||||
self._db_session = db_session
|
||||
|
||||
def _is_composio_connector(self, connector: SearchSourceConnector) -> bool:
|
||||
return (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
)
|
||||
|
||||
def _get_composio_connected_account_id(
|
||||
self, connector: SearchSourceConnector
|
||||
) -> str:
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected_account_id not found")
|
||||
return cca_id
|
||||
|
||||
async def _execute_composio_drive_tool(
|
||||
self,
|
||||
connector: SearchSourceConnector,
|
||||
tool_name: str,
|
||||
params: dict,
|
||||
) -> tuple[dict | list | None, str | None]:
|
||||
result = await ComposioService().execute_tool(
|
||||
connected_account_id=self._get_composio_connected_account_id(connector),
|
||||
tool_name=tool_name,
|
||||
params=params,
|
||||
entity_id=f"surfsense_{connector.user_id}",
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, result.get("error", "Unknown Composio Drive error")
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
inner = data.get("data", data)
|
||||
if isinstance(inner, dict):
|
||||
return inner.get("response_data", inner), None
|
||||
return inner, None
|
||||
return data, None
|
||||
|
||||
async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
|
||||
accounts = await self._get_google_drive_accounts(search_space_id, user_id)
|
||||
|
||||
|
|
@ -200,19 +236,21 @@ class GoogleDriveToolMetadataService:
|
|||
if not connector:
|
||||
return True
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
if self._is_composio_connector(connector):
|
||||
_data, error = await self._execute_composio_drive_tool(
|
||||
connector,
|
||||
"GOOGLEDRIVE_LIST_FILES",
|
||||
{
|
||||
"q": "trashed = false",
|
||||
"page_size": 1,
|
||||
"fields": "files(id)",
|
||||
},
|
||||
)
|
||||
return bool(error)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=self._db_session,
|
||||
connector_id=connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
await client.list_files(
|
||||
query="trashed = false", page_size=1, fields="files(id)"
|
||||
|
|
@ -274,19 +312,39 @@ class GoogleDriveToolMetadataService:
|
|||
parent_folders[connector_id] = []
|
||||
continue
|
||||
|
||||
pre_built_creds = None
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
if self._is_composio_connector(connector):
|
||||
data, error = await self._execute_composio_drive_tool(
|
||||
connector,
|
||||
"GOOGLEDRIVE_LIST_FILES",
|
||||
{
|
||||
"q": "mimeType = 'application/vnd.google-apps.folder' and trashed = false and 'root' in parents",
|
||||
"fields": "files(id,name)",
|
||||
"page_size": 50,
|
||||
},
|
||||
)
|
||||
if error:
|
||||
logger.warning(
|
||||
"Failed to list folders for connector %s: %s",
|
||||
connector_id,
|
||||
error,
|
||||
)
|
||||
parent_folders[connector_id] = []
|
||||
continue
|
||||
folders = []
|
||||
if isinstance(data, dict):
|
||||
folders = data.get("files", [])
|
||||
elif isinstance(data, list):
|
||||
folders = data
|
||||
parent_folders[connector_id] = [
|
||||
{"folder_id": f["id"], "name": f["name"]}
|
||||
for f in folders
|
||||
if f.get("id") and f.get("name")
|
||||
]
|
||||
continue
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=self._db_session,
|
||||
connector_id=connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
|
||||
folders, _, error = await client.list_files(
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ from typing import Any
|
|||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
|
|
@ -152,12 +154,12 @@ class ImageGenRouterService:
|
|||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
provider_prefix = config["custom_provider"]
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params: dict[str, Any] = {
|
||||
|
|
@ -165,9 +167,16 @@ class ImageGenRouterService:
|
|||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Add optional api_base
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
# Resolve ``api_base`` so deployments don't silently inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
|
||||
# the wrong provider (see ``provider_api_base`` docstring).
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add api_version (required for Azure)
|
||||
if config.get("api_version"):
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from litellm.exceptions import (
|
|||
BadRequestError as LiteLLMBadRequestError,
|
||||
ContextWindowExceededError,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
|
@ -133,42 +134,14 @@ PROVIDER_MAP = {
|
|||
}
|
||||
|
||||
|
||||
# Default ``api_base`` per LiteLLM provider prefix. Used as a safety net when
|
||||
# a global LLM config does *not* specify ``api_base``: without this, LiteLLM
|
||||
# happily picks up provider-agnostic env vars (e.g. ``AZURE_API_BASE``,
|
||||
# ``OPENAI_API_BASE``) and routes, say, an ``openrouter/anthropic/claude-3-haiku``
|
||||
# request to an Azure endpoint, which then 404s with ``Resource not found``.
|
||||
# Only providers with a well-known, stable public base URL are listed here —
|
||||
# self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
# huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
# so their existing config-driven behaviour is preserved.
|
||||
PROVIDER_DEFAULT_API_BASE = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||


# Canonical provider → base URL when a config uses a generic ``openai``-style
# prefix but the ``provider`` field tells us which API it really is
# (e.g. DeepSeek/Alibaba/Moonshot/Zhipu/MiniMax all use ``openai`` compat but
# each has its own base URL).
PROVIDER_KEY_DEFAULT_API_BASE = {
    "DEEPSEEK": "https://api.deepseek.com/v1",
    "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
    "MOONSHOT": "https://api.moonshot.ai/v1",
    "ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
    "MINIMAX": "https://api.minimax.io/v1",
}
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
# call sites can share the exact same defense (OpenRouter / Groq / etc.
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import (  # noqa: E402
    resolve_api_base,
)


class LLMRouterService:
|
||||
|
|
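# The resolution order ``resolve_api_base`` applies is not shown in this diff;
# a minimal sketch of the behaviour implied by the call sites (an explicit
# config value wins, then the provider-key default, then the prefix default)
# could look like the following -- an illustration that assumes access to the
# two default maps above, not the shared module's actual code:
def resolve_api_base_sketch(
    provider: str | None,
    provider_prefix: str | None,
    config_api_base: str | None,
) -> str | None:
    # 1. An explicit api_base on the config always wins.
    if config_api_base:
        return config_api_base
    # 2. Otherwise fall back to the canonical URL for the provider key
    #    (DeepSeek, Moonshot, ... using an ``openai``-style prefix).
    if provider and provider.upper() in PROVIDER_KEY_DEFAULT_API_BASE:
        return PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]
    # 3. Finally use the well-known public base URL for the LiteLLM prefix
    #    (openrouter, groq, ...), or None for BYO-endpoint providers.
    return PROVIDER_DEFAULT_API_BASE.get(provider_prefix or "")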
@ -207,6 +180,12 @@ class LLMRouterService:
|
|||
"""
|
||||
Initialize the router with global LLM configurations.
|
||||
|
||||
Configs with ``router_pool_eligible=False`` are skipped so that
|
||||
dynamic OpenRouter entries stay out of the shared router pool used
|
||||
by title-gen / sub-agent ``model="auto"`` flows. Those dynamic
|
||||
entries are still available for user-facing Auto-mode thread pinning
|
||||
via ``auto_model_pin_service``.
|
||||
|
||||
Args:
|
||||
global_configs: List of global LLM config dictionaries from YAML
|
||||
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
|
||||
|
|
@ -220,6 +199,8 @@ class LLMRouterService:
|
|||
model_list = []
|
||||
premium_models: set[str] = set()
|
||||
for config in global_configs:
|
||||
if config.get("router_pool_eligible") is False:
|
||||
continue
|
||||
deployment = cls._config_to_deployment(config)
|
||||
if deployment:
|
||||
model_list.append(deployment)
|
||||
|
|
@ -308,10 +289,45 @@ class LLMRouterService:
|
|||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def rebuild(
|
||||
cls,
|
||||
global_configs: list[dict],
|
||||
router_settings: dict | None = None,
|
||||
) -> None:
|
||||
"""Reset the router and re-run ``initialize`` with fresh configs.
|
||||
|
||||
``initialize`` short-circuits once it has run to avoid re-creating the
|
||||
LiteLLM Router on every request; ``rebuild`` deliberately clears
|
||||
``_initialized`` so a caller (e.g. background OpenRouter refresh)
|
||||
can force the pool to be rebuilt after catalogue changes.
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
instance._initialized = False
|
||||
instance._router = None
|
||||
instance._model_list = []
|
||||
instance._premium_model_strings = set()
|
||||
cls.initialize(global_configs, router_settings)
|
||||
|
||||
@classmethod
|
||||
def is_premium_model(cls, model_string: str) -> bool:
|
||||
"""Return True if *model_string* (as reported by LiteLLM) belongs to a
|
||||
premium-tier deployment in the router pool."""
|
||||
"""Return True if *model_string* belongs to a premium-tier deployment
|
||||
in the LiteLLM router pool.
|
||||
|
||||
Scope: only covers configs with ``router_pool_eligible`` truthy. That
|
||||
includes static YAML premium configs AND dynamic OpenRouter *premium*
|
||||
entries (which opt in at generation time). Dynamic OpenRouter *free*
|
||||
entries are deliberately kept out of the router pool — OpenRouter
|
||||
enforces free-tier limits globally per account, so per-deployment
|
||||
router accounting can't represent them correctly — and therefore
|
||||
return ``False`` here, which matches their ``billing_tier="free"``
|
||||
(no premium quota).
|
||||
|
||||
For per-request premium checks on an arbitrary config (static or
|
||||
dynamic, pool or non-pool), read ``agent_config.is_premium`` instead;
|
||||
that reflects the per-config ``billing_tier`` directly and is what
|
||||
user-facing Auto-mode thread pinning uses to bill correctly.
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
return model_string in instance._premium_model_strings
|
||||
|
||||
|
|
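# A hedged usage sketch for ``rebuild`` (the loader name and settings value
# below are illustrative, not taken from this diff): after the background
# OpenRouter catalogue refresh produces a new config list, the shared router
# pool can be forced to pick it up like so:
#
#     fresh_configs = load_global_llm_configs()   # hypothetical loader
#     LLMRouterService.rebuild(fresh_configs, router_settings=None)
#
# Unlike ``initialize``, this always tears down the existing LiteLLM Router
# before re-creating it, so stale deployments drop out of the pool.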
@ -422,14 +438,14 @@ class LLMRouterService:
|
|||
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||
# provider-aware default so the deployment does not silently
|
||||
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||
# requests to the wrong endpoint. See ``PROVIDER_DEFAULT_API_BASE``
|
||||
# requests to the wrong endpoint. See ``provider_api_base``
|
||||
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||
# against an Azure endpoint).
|
||||
api_base = config.get("api_base")
|
||||
if not api_base:
|
||||
api_base = PROVIDER_KEY_DEFAULT_API_BASE.get(provider)
|
||||
if not api_base:
|
||||
api_base = PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
|
|
@ -573,6 +589,11 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
# Public attributes that Pydantic will manage
|
||||
model: str = "auto"
|
||||
streaming: bool = True
|
||||
# Static kwargs that flow through to ``litellm.completion(...)`` on every
|
||||
# invocation (e.g. ``cache_control_injection_points`` set by
|
||||
# ``apply_litellm_prompt_caching``). Per-call ``**kwargs`` from
|
||||
# ``invoke()`` still take precedence — see ``_generate``/``_astream``.
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Bound tools and tool choice for tool calling
|
||||
_bound_tools: list[dict] | None = None
|
||||
|
|
@ -898,13 +919,16 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
logger.warning(f"Failed to convert tool {tool}: {e}")
|
||||
continue
|
||||
|
||||
# Create a new instance with tools bound
|
||||
# Create a new instance with tools bound. Carry through ``model_kwargs``
|
||||
# so static settings (e.g. cache_control_injection_points) survive the
|
||||
# bind_tools rebuild.
|
||||
return ChatLiteLLMRouter(
|
||||
router=self._router,
|
||||
bound_tools=formatted_tools if formatted_tools else None,
|
||||
tool_choice=tool_choice,
|
||||
model=self.model,
|
||||
streaming=self.streaming,
|
||||
model_kwargs=dict(self.model_kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
|
@ -929,8 +953,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||
# per-call kwargs so callers can still override per invocation. Then add
|
||||
# bound tools.
|
||||
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
|
|
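# Precedence sketch for the merge above (pure-Python illustration with
# placeholder values, not part of the diff): static ``model_kwargs`` provide
# defaults, per-call ``kwargs`` override them key by key.
static_model_kwargs = {"cache_control_injection_points": [...], "temperature": 0.2}
per_call_kwargs = {"temperature": 0.7}
merged = {**static_model_kwargs, **per_call_kwargs}
assert merged["temperature"] == 0.7                  # per-call value wins
assert "cache_control_injection_points" in merged    # static setting survives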
@ -997,8 +1023,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||
# per-call kwargs so callers can still override per invocation. Then add
|
||||
# bound tools.
|
||||
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
|
|
@ -1060,8 +1088,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||
# per-call kwargs so callers can still override per invocation. Then add
|
||||
# bound tools.
|
||||
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
|
|
@ -1110,8 +1140,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||
# per-call kwargs so callers can still override per invocation. Then add
|
||||
# bound tools.
|
||||
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||
if self._bound_tools:
|
||||
call_kwargs["tools"] = self._bound_tools
|
||||
if self._tool_choice is not None:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.services.llm_router_service import (
|
|||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
|
|
@ -496,8 +497,14 @@ async def get_vision_llm(
|
|||
- Auto mode (ID 0): VisionLLMRouterService
|
||||
- Global (negative ID): YAML configs
|
||||
- DB (positive ID): VisionLLMConfig table
|
||||
|
||||
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
|
||||
so each ``ainvoke`` debits the search-space owner's premium credit
|
||||
pool. User-owned BYOK configs and free global configs are returned
|
||||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.db import VisionLLMConfig
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
from app.services.vision_llm_router_service import (
|
||||
VISION_PROVIDER_MAP,
|
||||
VisionLLMRouterService,
|
||||
|
|
@ -519,6 +526,8 @@ async def get_vision_llm(
|
|||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
owner_user_id = search_space.user_id
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if not VisionLLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
|
|
@ -526,6 +535,13 @@ async def get_vision_llm(
|
|||
)
|
||||
return None
|
||||
try:
|
||||
# Auto mode is currently treated as free at the wrapper
|
||||
# level — the underlying router can dispatch to either
|
||||
# premium or free YAML configs but routing decisions are
|
||||
# opaque. If/when we want to bill Auto-routed vision
|
||||
# calls we'd need to thread the resolved deployment's
|
||||
# billing_tier back from the router. For now we keep
|
||||
# parity with chat Auto, which also doesn't pre-classify.
|
||||
return ChatLiteLLMRouter(
|
||||
router=VisionLLMRouterService.get_router(),
|
||||
streaming=True,
|
||||
|
|
@ -541,29 +557,46 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if global_cfg.get("custom_provider"):
|
||||
model_string = (
|
||||
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
|
||||
)
|
||||
provider_prefix = global_cfg["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
global_cfg["provider"].upper(),
|
||||
global_cfg["provider"].lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{global_cfg['model_name']}"
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_cfg["api_key"],
|
||||
}
|
||||
if global_cfg.get("api_base"):
|
||||
litellm_kwargs["api_base"] = global_cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=global_cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=global_cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
|
||||
if billing_tier == "premium":
|
||||
return QuotaCheckedVisionLLM(
|
||||
inner_llm,
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=model_string,
|
||||
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
|
||||
)
|
||||
return inner_llm
|
||||
|
||||
# User-owned (positive ID) BYOK configs — always free.
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
VisionLLMConfig.id == config_id,
|
||||
|
|
@ -578,20 +611,26 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if vision_cfg.custom_provider:
|
||||
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
|
||||
provider_prefix = vision_cfg.custom_provider
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
vision_cfg.provider.value.upper(),
|
||||
vision_cfg.provider.value.lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{vision_cfg.model_name}"
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": vision_cfg.api_key,
|
||||
}
|
||||
if vision_cfg.api_base:
|
||||
litellm_kwargs["api_base"] = vision_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=vision_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=vision_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
|
|
|
|||
|
|
@ -565,20 +565,31 @@ class VercelStreamingService:
|
|||
# Error Part
|
||||
# =========================================================================
|
||||
|
||||
def format_error(self, error_text: str) -> str:
|
||||
def format_error(
|
||||
self,
|
||||
error_text: str,
|
||||
error_code: str | None = None,
|
||||
extra: dict[str, object] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format an error message.
|
||||
|
||||
Args:
|
||||
error_text: The error message text
|
||||
error_code: Optional machine-readable error code for frontend branching
|
||||
|
||||
Returns:
|
||||
str: SSE formatted error part
|
||||
|
||||
Example output:
|
||||
data: {"type":"error","errorText":"Something went wrong"}
|
||||
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
||||
"""
|
||||
return self._format_sse({"type": "error", "errorText": error_text})
|
||||
payload: dict[str, object] = {"type": "error", "errorText": error_text}
|
||||
if error_code:
|
||||
payload["errorCode"] = error_code
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
# Tool Parts
|
||||
|
|
|
|||
|
|
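# Illustrative call of the extended ``format_error`` (the instance wiring and
# the error code string are made up for the example):
#
#     svc = VercelStreamingService()
#     svc.format_error("Rate limited", error_code="FREE_TIER_RATE_LIMIT")
#
# would emit an SSE part along the lines of:
#
#     data: {"type":"error","errorText":"Rate limited","errorCode":"FREE_TIER_RATE_LIMIT"}
#
# with any ``extra`` keys merged into the same JSON payload.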
@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path.
"""

import asyncio
import hashlib
import logging
import threading
import time
from typing import Any

import httpx

from app.services.quality_score import (
    _HEALTH_BLEND_WEIGHT,
    _HEALTH_ENRICH_CONCURRENCY,
    _HEALTH_ENRICH_TOP_N_FREE,
    _HEALTH_ENRICH_TOP_N_PREMIUM,
    _HEALTH_FAIL_RATIO_FALLBACK,
    _HEALTH_FETCH_TIMEOUT_SEC,
    aggregate_health,
    static_score_or,
)

logger = logging.getLogger(__name__)

OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
OPENROUTER_ENDPOINTS_URL_TEMPLATE = (
    "https://openrouter.ai/api/v1/models/{model_id}/endpoints"
)

# Sentinel value stored on each generated config so we can distinguish
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"

# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides
# enough headroom to avoid frequent collisions for OpenRouter's catalogue
# (~300 models) while keeping IDs comfortably within Postgres INTEGER range.
_STABLE_ID_HASH_WIDTH = 9_000_000


def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
    """Derive a deterministic negative config ID from ``model_id``.

    The same ``model_id`` always hashes to the same base value so thread pins
    survive catalogue churn (models appearing/disappearing/reordering between
    refreshes). On collision we decrement until we find an unused slot; this
    keeps the mapping stable for the first config that claimed a slot and
    only shifts collisions, which is much less disruptive than the legacy
    index-based scheme that reshuffled every ID when the catalogue changed.
    """
    digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest()
    base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH)
    cid = base
    while cid in taken:
        cid -= 1
    taken.add(cid)
    return cid
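

# Behaviour sketch for ``_stable_config_id`` (the model id below is
# illustrative, not a real catalogue entry): the same model id maps to the
# same slot across refreshes, and only a claim on an already-taken slot gets
# nudged to the next one.
_taken: set[int] = set()
first = _stable_config_id("meta-llama/llama-3.1-8b-instruct:free", -10000, _taken)
again = _stable_config_id("meta-llama/llama-3.1-8b-instruct:free", -10000, set())
assert first == again          # deterministic across refreshes
collided = _stable_config_id("meta-llama/llama-3.1-8b-instruct:free", -10000, _taken)
assert collided == first - 1   # re-claiming a taken slot decrements once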


def _openrouter_tier(model: dict) -> str:
    """Classify an OpenRouter model as ``"free"`` or ``"premium"``.

    Per OpenRouter's API contract, a model is free if:
    - Its id ends with ``:free`` (OpenRouter's own free-variant convention), or
    - Both ``pricing.prompt`` and ``pricing.completion`` are zero strings.

    Anything else (missing pricing, non-zero pricing) falls through to
    ``"premium"`` so we never under-charge users. This derivation runs off the
    already-cached /api/v1/models payload, so it adds no network cost.
    """
    if model.get("id", "").endswith(":free"):
        return "free"
    pricing = model.get("pricing") or {}
    prompt = str(pricing.get("prompt", "")).strip()
    completion = str(pricing.get("completion", "")).strip()
    if prompt == "0" and completion == "0":
        return "free"
    return "premium"
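
# Classification sketch (hand-written payload fragments, not live API data):
assert _openrouter_tier({"id": "google/gemma-2-9b-it:free"}) == "free"
assert _openrouter_tier(
    {"id": "anthropic/claude-3.5-sonnet",
     "pricing": {"prompt": "0.000003", "completion": "0.000015"}}
) == "premium"
# Missing pricing also falls through to "premium" so users are never under-charged.
assert _openrouter_tier({"id": "some-vendor/new-model"}) == "premium"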
|
||||
|
||||
def _is_text_output_model(model: dict) -> bool:
|
||||
"""Return True if the model produces text output only (skip image/audio generators)."""
|
||||
|
|
@ -32,6 +93,53 @@ def _is_text_output_model(model: dict) -> bool:
|
|||
return output_mods == ["text"]
|
||||
|
||||
|
||||
def _is_image_output_model(model: dict) -> bool:
|
||||
"""Return True if the model can produce image output.
|
||||
|
||||
OpenRouter's ``architecture.output_modalities`` is a list (e.g.
|
||||
``["image"]`` for pure image generators, ``["text", "image"]`` for
|
||||
multi-modal generators that also emit captions). We accept any model
|
||||
that can output images; the call site decides whether to use the
|
||||
image-generation API or chat completion.
|
||||
"""
|
||||
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
|
||||
return "image" in output_mods
|
||||
|
||||
|
||||
def _is_vision_input_model(model: dict) -> bool:
|
||||
"""Return True if the model can ingest an image AND emit text.
|
||||
|
||||
OpenRouter's ``architecture.input_modalities`` lists what the model
|
||||
accepts; ``output_modalities`` lists what it produces. A vision LLM
|
||||
is a model that takes images in and produces text out — i.e. it can
|
||||
answer questions about a screenshot or extract content from an
|
||||
image. Pure image-to-image models (e.g. style transfer) and
|
||||
text-only models are excluded.
|
||||
"""
|
||||
arch = model.get("architecture", {}) or {}
|
||||
input_mods = arch.get("input_modalities", []) or []
|
||||
output_mods = arch.get("output_modalities", []) or []
|
||||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _supports_image_input(model: dict) -> bool:
|
||||
"""Return True if the model accepts ``image`` in its input modalities.
|
||||
|
||||
Differs from :func:`_is_vision_input_model` in that it does NOT
|
||||
require text output — chat-tab models always emit text already (the
|
||||
chat catalog filters by ``_is_text_output_model``), so the only
|
||||
extra capability we need to track per chat config is whether the
|
||||
model can ingest user-attached images. The chat selector and the
|
||||
streaming task both key off this flag to prevent hitting an
|
||||
OpenRouter 404 ``"No endpoints found that support image input"``
|
||||
when the user uploads an image and selects a text-only model
|
||||
(DeepSeek V3, Llama 3.x base, etc.).
|
||||
"""
|
||||
arch = model.get("architecture", {}) or {}
|
||||
input_mods = arch.get("input_modalities", []) or []
|
||||
return "image" in input_mods
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
|
|
@ -56,6 +164,11 @@ _EXCLUDED_MODEL_IDS: set[str] = {
|
|||
# Deep-research models reject standard params (temperature, etc.)
|
||||
"openai/o3-deep-research",
|
||||
"openai/o4-mini-deep-research",
|
||||
# OpenRouter's own meta-router over free models. We already enumerate every
|
||||
# concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread
|
||||
# pinning handles churn via the repair path, so exposing an additional
|
||||
# indirection layer would only duplicate the capability with an opaque slug.
|
||||
"openrouter/free",
|
||||
}
|
||||
|
||||
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
||||
|
|
@ -109,24 +222,71 @@ async def _fetch_models_async() -> list[dict] | None:
|
|||
return None
|
||||


def _extract_raw_pricing(raw_models: list[dict]) -> dict[str, dict[str, str]]:
    """Return a ``{model_id: {"prompt": str, "completion": str}}`` map.

    Pricing values are kept as the raw OpenRouter strings (e.g.
    ``"0.000003"``); ``pricing_registration`` converts them to floats
    when registering with LiteLLM. Models with missing or malformed
    pricing are simply omitted — operator-side risk if any of those are
    premium.
    """
    pricing: dict[str, dict[str, str]] = {}
    for model in raw_models:
        model_id = str(model.get("id") or "").strip()
        if not model_id:
            continue
        p = model.get("pricing") or {}
        prompt = p.get("prompt")
        completion = p.get("completion")
        if prompt is None and completion is None:
            continue
        pricing[model_id] = {
            "prompt": str(prompt) if prompt is not None else "",
            "completion": str(completion) if completion is not None else "",
        }
    return pricing
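
# Shape sketch of the returned map (model ids and prices are placeholders):
#
#     {
#         "openai/gpt-4o-mini": {"prompt": "0.00000015", "completion": "0.0000006"},
#         "meta-llama/llama-3.1-8b-instruct:free": {"prompt": "0", "completion": "0"},
#     }
#
# Entries whose payload carries neither a prompt nor a completion price are
# simply absent from the map.

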
def _generate_configs(
|
||||
raw_models: list[dict],
|
||||
settings: dict[str, Any],
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Convert raw OpenRouter model entries into global LLM config dicts.
|
||||
"""Convert raw OpenRouter model entries into global LLM config dicts.
|
||||
|
||||
Models are sorted by ID for deterministic, stable ID assignment across
|
||||
restarts and refreshes.
|
||||
Tier (``billing_tier``) is derived per-model from OpenRouter's own API
|
||||
signals via ``_openrouter_tier`` — there is no longer a uniform YAML
|
||||
override. Config IDs are derived via ``_stable_config_id`` so they
|
||||
survive catalogue churn across refreshes.
|
||||
|
||||
Router-pool membership is tier-aware:
|
||||
|
||||
- Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``)
|
||||
so sub-agent ``model="auto"`` flows benefit from load balancing and
|
||||
failover across the curated YAML configs and the OR premium passthrough.
|
||||
- Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM
|
||||
Router tracks rate limits per deployment, but OpenRouter enforces a
|
||||
single global free-tier quota (~20 RPM + 50-1000 daily requests
|
||||
account-wide across every ``:free`` model), so rotating across many
|
||||
free deployments would only burn the shared bucket faster. Free OR
|
||||
models remain fully available for user-facing Auto-mode thread pinning
|
||||
via ``auto_model_pin_service``.
|
||||
|
||||
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
|
||||
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
|
||||
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
|
||||
cover the catalogue-churn case.
|
||||
"""
|
||||
id_offset: int = settings.get("id_offset", -10000)
|
||||
api_key: str = settings.get("api_key", "")
|
||||
billing_tier: str = settings.get("billing_tier", "premium")
|
||||
anonymous_enabled: bool = settings.get("anonymous_enabled", False)
|
||||
seo_enabled: bool = settings.get("seo_enabled", False)
|
||||
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||
rpm: int = settings.get("rpm", 200)
|
||||
tpm: int = settings.get("tpm", 1000000)
|
||||
tpm: int = settings.get("tpm", 1_000_000)
|
||||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
free_tpm: int = settings.get("free_tpm", 100_000)
|
||||
anon_paid: bool = settings.get("anonymous_enabled_paid", False)
|
||||
anon_free: bool = settings.get("anonymous_enabled_free", False)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
system_instructions: str = settings.get("system_instructions", "")
|
||||
use_default: bool = settings.get("use_default_system_instructions", True)
|
||||
|
|
@ -142,19 +302,24 @@ def _generate_configs(
|
|||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
text_models.sort(key=lambda m: m["id"])
|
||||
|
||||
configs: list[dict] = []
|
||||
for idx, model in enumerate(text_models):
|
||||
taken: set[int] = set()
|
||||
now_ts = int(time.time())
|
||||
|
||||
for model in text_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
tier = _openrouter_tier(model)
|
||||
|
||||
static_q = static_score_or(model, now_ts=now_ts)
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"id": id_offset - idx,
|
||||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter",
|
||||
"billing_tier": billing_tier,
|
||||
"anonymous_enabled": anonymous_enabled,
|
||||
"billing_tier": tier,
|
||||
"anonymous_enabled": anon_free if tier == "free" else anon_paid,
|
||||
"seo_enabled": seo_enabled,
|
||||
"seo_slug": None,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
|
|
@ -162,12 +327,199 @@ def _generate_configs(
|
|||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "",
|
||||
"rpm": rpm,
|
||||
"tpm": tpm,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"tpm": free_tpm if tier == "free" else tpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
"system_instructions": system_instructions,
|
||||
"use_default_system_instructions": use_default,
|
||||
"citations_enabled": citations_enabled,
|
||||
# Premium OR deployments join the LiteLLM router pool so sub-agent
|
||||
# model="auto" flows can load-balance / fail over across them.
|
||||
# Free OR deployments stay out: OpenRouter's free tier is a single
|
||||
# account-wide quota, so per-deployment routing can't spread load
|
||||
# there — it just drains the shared bucket faster.
|
||||
"router_pool_eligible": tier == "premium",
|
||||
# Capability flag derived from ``architecture.input_modalities``.
|
||||
# Read by the new-chat selector to dim image-incompatible models
|
||||
# when the user has pending image attachments, and by
|
||||
# ``stream_new_chat`` as a fail-fast safety net before the
|
||||
# OpenRouter request would otherwise 404 with
|
||||
# ``"No endpoints found that support image input"``.
|
||||
"supports_image_input": _supports_image_input(model),
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||
# to the static score and gets re-blended with health on the next
|
||||
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
|
||||
# start so startup latency is unchanged).
|
||||
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||
"quality_score_static": static_q,
|
||||
"quality_score_health": None,
|
||||
"quality_score": static_q,
|
||||
"health_gated": False,
|
||||
}
|
||||
configs.append(cfg)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
# ID-offset bands used to keep dynamic OpenRouter configs in their own
|
||||
# namespace per surface. Image / vision get separate bands so a single
|
||||
# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to.
|
||||
_OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000
|
||||
_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000
|
||||
|
||||
|
||||
def _generate_image_gen_configs(
|
||||
raw_models: list[dict], settings: dict[str, Any]
|
||||
) -> list[dict]:
|
||||
"""Convert OpenRouter image-generation models into global image-gen
|
||||
config dicts (matches the YAML shape consumed by ``image_generation_routes``).
|
||||
|
||||
Filter:
|
||||
- architecture.output_modalities contains "image"
|
||||
- compatible provider (excluded slugs blocked)
|
||||
- allowed model id (excluded list blocked)
|
||||
|
||||
Notably we *drop* the chat-only filters (``_supports_tool_calling`` and
|
||||
``_has_sufficient_context``) because tool calls and context windows are
|
||||
irrelevant for the ``aimage_generation`` API. ``billing_tier`` is
|
||||
derived per model the same way as chat (``_openrouter_tier``).
|
||||
|
||||
Cost is intentionally *not* registered with LiteLLM at startup
|
||||
(``pricing_registration`` skips image gen): OpenRouter image-gen
|
||||
models are not in LiteLLM's native cost map and OpenRouter populates
|
||||
``response_cost`` directly from the response header. A defensive
|
||||
branch in ``_extract_cost_usd`` handles the rare case where
|
||||
``usage.cost`` is missing — see ``token_tracking_service``.
|
||||
"""
|
||||
id_offset: int = int(
|
||||
settings.get("image_id_offset") or _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT
|
||||
)
|
||||
api_key: str = settings.get("api_key", "")
|
||||
rpm: int = settings.get("rpm", 200)
|
||||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
|
||||
image_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_image_output_model(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
for model in image_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
tier = _openrouter_tier(model)
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter (image generation)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
# Pin to OpenRouter's public base URL so a downstream call site
|
||||
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
|
||||
# ``image_generation/transformation`` (defense-in-depth, see
|
||||
# ``provider_api_base`` docstring).
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
"billing_tier": tier,
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
}
|
||||
configs.append(cfg)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def _generate_vision_llm_configs(
|
||||
raw_models: list[dict], settings: dict[str, Any]
|
||||
) -> list[dict]:
|
||||
"""Convert OpenRouter vision-capable LLMs into global vision-LLM config
|
||||
dicts (matches the YAML shape consumed by ``vision_llm_routes``).
|
||||
|
||||
Filter:
|
||||
- architecture.input_modalities contains "image"
|
||||
- architecture.output_modalities contains "text"
|
||||
- compatible provider (excluded slugs blocked)
|
||||
- allowed model id (excluded list blocked)
|
||||
|
||||
Vision-LLM is invoked from the indexer (image extraction during
|
||||
document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so
|
||||
the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context``
|
||||
filters do not apply: a small-context vision model that doesn't
|
||||
advertise tool-calling is still perfectly viable for "describe this
|
||||
image" prompts.
|
||||
"""
|
||||
id_offset: int = int(
|
||||
settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT
|
||||
)
|
||||
api_key: str = settings.get("api_key", "")
|
||||
rpm: int = settings.get("rpm", 200)
|
||||
tpm: int = settings.get("tpm", 1_000_000)
|
||||
free_rpm: int = settings.get("free_rpm", 20)
|
||||
free_tpm: int = settings.get("free_tpm", 100_000)
|
||||
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||
litellm_params: dict = settings.get("litellm_params") or {}
|
||||
|
||||
vision_models = [
|
||||
m
|
||||
for m in raw_models
|
||||
if _is_vision_input_model(m)
|
||||
and _is_compatible_provider(m)
|
||||
and _is_allowed_model(m)
|
||||
and "/" in m.get("id", "")
|
||||
]
|
||||
|
||||
configs: list[dict] = []
|
||||
taken: set[int] = set()
|
||||
for model in vision_models:
|
||||
model_id: str = model["id"]
|
||||
name: str = model.get("name", model_id)
|
||||
tier = _openrouter_tier(model)
|
||||
pricing = model.get("pricing") or {}
|
||||
|
||||
# Capture per-token prices so ``pricing_registration`` can
|
||||
# register them with LiteLLM at startup (and so the cost
|
||||
# estimator in ``estimate_call_reserve_micros`` can resolve
|
||||
# them at reserve time).
|
||||
try:
|
||||
input_cost = float(pricing.get("prompt", 0) or 0)
|
||||
except (TypeError, ValueError):
|
||||
input_cost = 0.0
|
||||
try:
|
||||
output_cost = float(pricing.get("completion", 0) or 0)
|
||||
except (TypeError, ValueError):
|
||||
output_cost = 0.0
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"id": _stable_config_id(model_id, id_offset, taken),
|
||||
"name": name,
|
||||
"description": f"{name} via OpenRouter (vision)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
# Pin to OpenRouter's public base URL so a downstream call site
|
||||
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
|
||||
# ``provider_api_base`` docstring).
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"tpm": free_tpm if tier == "free" else tpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
"billing_tier": tier,
|
||||
"quota_reserve_tokens": quota_reserve_tokens,
|
||||
"input_cost_per_token": input_cost or None,
|
||||
"output_cost_per_token": output_cost or None,
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
}
|
||||
configs.append(cfg)
|
||||
|
|
@@ -187,6 +539,25 @@ class OpenRouterIntegrationService:

        self._configs_by_id: dict[int, dict] = {}
        self._initialized = False
        self._refresh_task: asyncio.Task | None = None
        # Last-good per-model health snapshot. Survives across refresh
        # cycles so a transient OpenRouter /endpoints outage doesn't drop
        # every cfg back to static-only scoring.
        # Shape: {model_name: {"gated": bool, "score": float | None}}
        self._health_cache: dict[str, dict[str, Any]] = {}
        self._enrich_task: asyncio.Task | None = None
        # Raw OpenRouter pricing per model_id, captured at the same time
        # we generate configs. Consumed by ``pricing_registration`` to
        # teach LiteLLM the per-token cost of every dynamic deployment so
        # the success-callback can populate ``response_cost`` correctly.
        self._raw_pricing: dict[str, dict[str, str]] = {}
        # Cached raw catalogue from the most recent fetch. Image / vision
        # emitters reuse this to avoid a second network call per surface.
        self._raw_models: list[dict] = []
        # Image / vision config caches (only populated when the matching
        # opt-in flag is true on initialize). Refreshed in lockstep with
        # the chat catalogue.
        self._image_configs: list[dict] = []
        self._vision_configs: list[dict] = []

    @classmethod
    def get_instance(cls) -> "OpenRouterIntegrationService":
@@ -216,16 +587,55 @@ class OpenRouterIntegrationService:
|
|||
self._initialized = True
|
||||
return []
|
||||
|
||||
self._raw_models = raw_models
|
||||
self._configs = _generate_configs(raw_models, settings)
|
||||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||
self._raw_pricing = _extract_raw_pricing(raw_models)
|
||||
|
||||
# Populate image / vision caches when their opt-in flag is set.
|
||||
# Empty otherwise so the accessors return [] without re-running
|
||||
# filters every refresh.
|
||||
if settings.get("image_generation_enabled"):
|
||||
self._image_configs = _generate_image_gen_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: image-gen emission ON (%d models)",
|
||||
len(self._image_configs),
|
||||
)
|
||||
else:
|
||||
self._image_configs = []
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
|
||||
logger.info(
|
||||
"OpenRouter integration: vision LLM emission ON (%d models)",
|
||||
len(self._vision_configs),
|
||||
)
|
||||
else:
|
||||
self._vision_configs = []
|
||||
|
||||
self._initialized = True
|
||||
|
||||
tier_counts = self._tier_counts(self._configs)
|
||||
logger.info(
|
||||
"OpenRouter integration: loaded %d models (IDs %d to %d)",
|
||||
"OpenRouter integration: loaded %d models (free=%d, premium=%d)",
|
||||
len(self._configs),
|
||||
self._configs[0]["id"] if self._configs else 0,
|
||||
self._configs[-1]["id"] if self._configs else 0,
|
||||
tier_counts["free"],
|
||||
tier_counts["premium"],
|
||||
)
|
||||
|
||||
# Schedule the first health-enrichment pass as a deferred task so
|
||||
# cold-start latency is unchanged. Only valid when an event loop is
|
||||
# already running (e.g. FastAPI lifespan); Celery worker init is
|
||||
# fully sync so we silently skip — its first refresh tick (or the
|
||||
# next refresh from the web process) will populate health data.
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._enrich_task = loop.create_task(
|
||||
self._enrich_health_safely(self._configs)
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return self._configs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@@ -241,6 +651,8 @@ class OpenRouterIntegrationService:

        new_configs = _generate_configs(raw_models, self._settings)
        new_by_id = {c["id"]: c for c in new_configs}
        self._raw_pricing = _extract_raw_pricing(raw_models)
        self._raw_models = raw_models

        from app.config import config as app_config

@@ -254,7 +666,263 @@ class OpenRouterIntegrationService:
|
|||
self._configs = new_configs
|
||||
self._configs_by_id = new_by_id
|
||||
|
||||
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
|
||||
# Image / vision lists are atomic-swapped the same way: filter out
|
||||
# the previous dynamic entries from the live config list and append
|
||||
# the freshly generated ones. No-ops when the opt-in flag is off.
|
||||
if self._settings.get("image_generation_enabled"):
|
||||
new_image = _generate_image_gen_configs(raw_models, self._settings)
|
||||
static_image = [
|
||||
c
|
||||
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
|
||||
self._image_configs = new_image
|
||||
|
||||
if self._settings.get("vision_enabled"):
|
||||
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
|
||||
static_vision = [
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
|
||||
]
|
||||
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
|
||||
self._vision_configs = new_vision
|
||||
|
||||
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||
# earned by the previous turn's preflight. Drop the whole table so
|
||||
# the next turn re-probes against the freshly loaded configs.
|
||||
try:
|
||||
from app.services.auto_model_pin_service import clear_healthy
|
||||
|
||||
clear_healthy()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"OpenRouter refresh: clear_healthy import skipped", exc_info=True
|
||||
)
|
||||
|
||||
tier_counts = self._tier_counts(new_configs)
|
||||
logger.info(
|
||||
"OpenRouter refresh: updated to %d models (free=%d, premium=%d)",
|
||||
len(new_configs),
|
||||
tier_counts["free"],
|
||||
tier_counts["premium"],
|
||||
)
|
||||
|
||||
# Re-blend health scores against the freshly fetched catalogue. Also
|
||||
# re-stamps health for any YAML-curated cfg with provider==OPENROUTER
|
||||
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||
|
||||
# Re-register LiteLLM pricing for the freshly fetched catalogue
|
||||
# so newly added OR models bill correctly on their first call.
|
||||
# Runs before the router rebuild because the router may issue
|
||||
# cost-table lookups during deployment registration.
|
||||
try:
|
||||
from app.services.pricing_registration import (
|
||||
register_pricing_from_global_configs,
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"OpenRouter refresh: pricing re-registration skipped (%s)", exc
|
||||
)
|
||||
|
||||
# Rebuild the LiteLLM router so freshly fetched configs flow through
|
||||
# (dynamic OR premium entries now opt into the pool, free ones stay
|
||||
# out; a refresh also needs to pick up any static-config edits and
|
||||
# reset cached context-window profiles).
|
||||
try:
|
||||
from app.config import config as _app_config
|
||||
from app.services.llm_router_service import (
|
||||
LLMRouterService,
|
||||
_router_instance_cache as _chat_router_cache,
|
||||
)
|
||||
|
||||
LLMRouterService.rebuild(
|
||||
_app_config.GLOBAL_LLM_CONFIGS,
|
||||
getattr(_app_config, "ROUTER_SETTINGS", None),
|
||||
)
|
||||
_chat_router_cache.clear()
|
||||
except Exception as exc:
|
||||
logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc)
|
||||
|
||||
@staticmethod
|
||||
def _tier_counts(configs: list[dict]) -> dict[str, int]:
|
||||
counts = {"free": 0, "premium": 0}
|
||||
for cfg in configs:
|
||||
tier = str(cfg.get("billing_tier", "")).lower()
|
||||
if tier in counts:
|
||||
counts[tier] += 1
|
||||
return counts
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto (Fastest) health enrichment
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _enrich_health_safely(
|
||||
self, configs: list[dict], *, log_summary: bool = True
|
||||
) -> None:
|
||||
"""Wrapper around ``_enrich_health`` that swallows all errors.
|
||||
|
||||
Health enrichment is best-effort: any failure must leave cfgs in
|
||||
their static-only state and never break refresh / startup.
|
||||
"""
|
||||
try:
|
||||
await self._enrich_health(configs, log_summary=log_summary)
|
||||
except Exception:
|
||||
logger.exception("OpenRouter health enrichment failed")
|
||||
|
||||
async def _enrich_health(
|
||||
self, configs: list[dict], *, log_summary: bool = True
|
||||
) -> None:
|
||||
"""Fetch per-model ``/endpoints`` data for the top OR cfgs and blend
|
||||
the resulting health score into ``cfg["quality_score"]``.
|
||||
|
||||
Bounded fan-out: top-N per tier by ``quality_score_static`` only,
|
||||
with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the
|
||||
outbound HTTP. Misses fall back to a per-model last-good cache; if
|
||||
the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep
|
||||
the entire previous cycle's cache for this run.
|
||||
"""
|
||||
or_cfgs = [
|
||||
c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
|
||||
]
|
||||
if not or_cfgs:
|
||||
return
|
||||
|
||||
premium_pool = sorted(
|
||||
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"],
|
||||
key=lambda c: -int(c.get("quality_score_static") or 0),
|
||||
)[:_HEALTH_ENRICH_TOP_N_PREMIUM]
|
||||
free_pool = sorted(
|
||||
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"],
|
||||
key=lambda c: -int(c.get("quality_score_static") or 0),
|
||||
)[:_HEALTH_ENRICH_TOP_N_FREE]
|
||||
# De-duplicate while preserving order: a cfg shouldn't fall in both
|
||||
# tiers, but defensive code is cheap here.
|
||||
seen_ids: set[int] = set()
|
||||
selected: list[dict] = []
|
||||
for cfg in premium_pool + free_pool:
|
||||
cid = int(cfg.get("id", 0))
|
||||
if cid in seen_ids:
|
||||
continue
|
||||
seen_ids.add(cid)
|
||||
selected.append(cfg)
|
||||
|
||||
if not selected:
|
||||
return
|
||||
|
||||
api_key = str(self._settings.get("api_key") or "")
|
||||
semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client:
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
self._fetch_endpoints(client, semaphore, api_key, cfg)
|
||||
for cfg in selected
|
||||
)
|
||||
)
|
||||
|
||||
fail_count = sum(1 for _, _, err in results if err is not None)
|
||||
fail_ratio = fail_count / len(results) if results else 0.0
|
||||
degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK
|
||||
if degraded:
|
||||
logger.warning(
|
||||
"auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d "
|
||||
"using_last_good_cache=true",
|
||||
fail_ratio,
|
||||
len(results),
|
||||
)
|
||||
|
||||
# Per-cfg health update.
|
||||
for cfg, endpoints, err in results:
|
||||
model_name = str(cfg.get("model_name", ""))
|
||||
if not degraded and err is None and endpoints is not None:
|
||||
gated, h_score = aggregate_health(endpoints)
|
||||
cfg["health_gated"] = bool(gated)
|
||||
cfg["quality_score_health"] = h_score
|
||||
self._health_cache[model_name] = {
|
||||
"gated": bool(gated),
|
||||
"score": h_score,
|
||||
}
|
||||
else:
|
||||
cached = self._health_cache.get(model_name)
|
||||
if cached is not None:
|
||||
cfg["health_gated"] = bool(cached.get("gated", False))
|
||||
cfg["quality_score_health"] = cached.get("score")
|
||||
# else: keep current values (initial defaults from
|
||||
# _generate_configs / load_global_llm_configs).
|
||||
|
||||
# Blend health into the final score for every OR cfg, including
|
||||
# those outside the enriched top-N (they fall through to static).
|
||||
gated_count = 0
|
||||
by_provider: dict[str, int] = {}
|
||||
for cfg in or_cfgs:
|
||||
static_q = int(cfg.get("quality_score_static") or 0)
|
||||
h = cfg.get("quality_score_health")
|
||||
if h is not None and not cfg.get("health_gated"):
|
||||
blended = (
|
||||
_HEALTH_BLEND_WEIGHT * float(h)
|
||||
+ (1 - _HEALTH_BLEND_WEIGHT) * static_q
|
||||
)
|
||||
cfg["quality_score"] = round(blended)
|
||||
else:
|
||||
cfg["quality_score"] = static_q
|
||||
|
||||
if cfg.get("health_gated"):
|
||||
gated_count += 1
|
||||
model_id = str(cfg.get("model_name", ""))
|
||||
provider_slug = (
|
||||
model_id.split("/", 1)[0] if "/" in model_id else "unknown"
|
||||
)
|
||||
by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1
|
||||
|
||||
if log_summary:
|
||||
logger.info(
|
||||
"auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f "
|
||||
"total_enriched=%d",
|
||||
gated_count,
|
||||
dict(sorted(by_provider.items(), key=lambda kv: -kv[1])),
|
||||
fail_ratio,
|
||||
len(selected),
|
||||
)
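# Worked example of the blend (illustrative numbers, not from a real model):
# with _HEALTH_BLEND_WEIGHT = 0.5, a cfg whose quality_score_static is 60 and
# whose fresh /endpoints health score is 80 ends up with
# quality_score = round(0.5 * 80 + 0.5 * 60) = 70, while a health-gated cfg
# falls back to quality_score = 60 (static only).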
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_endpoints(
|
||||
client: httpx.AsyncClient,
|
||||
semaphore: asyncio.Semaphore,
|
||||
api_key: str,
|
||||
cfg: dict,
|
||||
) -> tuple[dict, list[dict] | None, Exception | None]:
|
||||
"""Fetch ``/api/v1/models/{id}/endpoints`` for one cfg.
|
||||
|
||||
Returns ``(cfg, endpoints, err)`` so the caller can keep batched
|
||||
results aligned with their cfgs without raising.
|
||||
"""
|
||||
model_id = str(cfg.get("model_name", ""))
|
||||
if not model_id:
|
||||
return cfg, None, ValueError("missing model_name")
|
||||
|
||||
url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id)
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
resp = await client.get(url, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception as exc:
|
||||
return cfg, None, exc
|
||||
|
||||
payload = data.get("data") if isinstance(data, dict) else None
|
||||
if not isinstance(payload, dict):
|
||||
return cfg, None, ValueError("malformed endpoints payload")
|
||||
endpoints = payload.get("endpoints")
|
||||
if not isinstance(endpoints, list):
|
||||
return cfg, [], None
|
||||
return cfg, endpoints, None
|
||||
|
||||
async def _refresh_loop(self, interval_hours: float) -> None:
|
||||
interval_sec = interval_hours * 3600
|
||||
|
|
@@ -289,3 +957,34 @@ class OpenRouterIntegrationService:
|
|||
|
||||
def get_config_by_id(self, config_id: int) -> dict | None:
|
||||
return self._configs_by_id.get(config_id)
|
||||
|
||||
def get_image_generation_configs(self) -> list[dict]:
|
||||
"""Return the dynamic OpenRouter image-generation configs (empty
|
||||
list when the ``image_generation_enabled`` flag is off).
|
||||
|
||||
Each entry already has ``billing_tier`` derived per-model from
|
||||
OpenRouter's signals and is shaped to drop directly into
|
||||
``Config.GLOBAL_IMAGE_GEN_CONFIGS``.
|
||||
"""
|
||||
return list(self._image_configs)
|
||||
|
||||
def get_vision_llm_configs(self) -> list[dict]:
|
||||
"""Return the dynamic OpenRouter vision-LLM configs (empty list
|
||||
when the ``vision_enabled`` flag is off).
|
||||
|
||||
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
|
||||
so ``pricing_registration`` can teach LiteLLM the cost of these
|
||||
models the same way it does for chat — which keeps the billable
|
||||
wrapper able to debit accurate micro-USD on a vision call.
|
||||
"""
|
||||
return list(self._vision_configs)
|
||||
|
||||
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
|
||||
"""Return the cached raw OpenRouter pricing map.
|
||||
|
||||
Shape: ``{model_id: {"prompt": str, "completion": str}}``. The
|
||||
values are the strings OpenRouter publishes (USD per token),
|
||||
never converted to floats here so the caller can decide how to
|
||||
handle malformed or unset entries.
|
||||
"""
|
||||
return dict(self._raw_pricing)
|
||||
|
|
|
|||
274
surfsense_backend/app/services/pricing_registration.py
Normal file
|
|
@@ -0,0 +1,274 @@
|
|||
"""
|
||||
Pricing registration with LiteLLM.
|
||||
|
||||
Many models reach our LiteLLM callback without LiteLLM knowing their
|
||||
per-token cost — namely:
|
||||
|
||||
* The ~300 dynamic OpenRouter deployments (their pricing only lives on
|
||||
OpenRouter's ``/api/v1/models`` payload, never in LiteLLM's published
|
||||
pricing table).
|
||||
* Static YAML deployments whose ``base_model`` name is operator-defined
|
||||
(e.g. custom Azure deployment names like ``gpt-5.4``) and therefore
|
||||
not in LiteLLM's table either.
|
||||
|
||||
Without registration, ``kwargs["response_cost"]`` is 0 for those calls
|
||||
and the user gets billed nothing — a fail-safe but wrong answer for a
|
||||
cost-based credit system. This module runs once at startup, after the
|
||||
OpenRouter integration has fetched its catalogue, and registers each
|
||||
known model's pricing with ``litellm.register_model()`` under multiple
|
||||
plausible alias keys (LiteLLM's cost lookup may use any of them
|
||||
depending on whether the call went through the Router, ChatLiteLLM,
|
||||
or a direct ``acompletion``).
|
||||
|
||||
Operators who run a custom Azure deployment whose ``base_model`` name
|
||||
isn't in LiteLLM's table can declare per-token pricing inline in
|
||||
``global_llm_config.yaml`` via ``input_cost_per_token`` and
|
||||
``output_cost_per_token`` (USD per token, e.g. ``0.000002``). Without
|
||||
that declaration the model's calls debit 0 — never overbilled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_float(value: Any) -> float:
|
||||
"""Return ``float(value)`` if it parses to a positive number, else 0.0."""
|
||||
if value is None:
|
||||
return 0.0
|
||||
try:
|
||||
f = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
return f if f > 0 else 0.0
|
||||
|
||||
|
||||
def _alias_set_for_openrouter(model_id: str) -> list[str]:
|
||||
"""Return the alias keys to register an OpenRouter model under.
|
||||
|
||||
LiteLLM's cost-callback lookup key varies by call path:
|
||||
- Router with ``model="openrouter/X"`` → kwargs["model"] is
|
||||
typically ``openrouter/X``.
|
||||
- LiteLLM's own provider routing may strip the prefix and pass the
|
||||
bare ``X`` to the cost-table lookup.
|
||||
Registering under both keeps the lookup hermetic regardless of
|
||||
which path the call took.
|
||||
"""
|
||||
aliases = [f"openrouter/{model_id}", model_id]
|
||||
return list(dict.fromkeys(a for a in aliases if a))
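# Illustrative expansion (model id assumed for the example):
#   _alias_set_for_openrouter("openai/gpt-4o")
#   -> ["openrouter/openai/gpt-4o", "openai/gpt-4o"]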
|
||||
|
||||
|
||||
def _alias_set_for_yaml(provider: str, model_name: str, base_model: str) -> list[str]:
|
||||
"""Return the alias keys to register a static YAML deployment under.
|
||||
|
||||
Same reasoning as the OpenRouter set: cover the bare ``base_model``,
|
||||
the ``<provider>/<model>`` form LiteLLM Router constructs, and the
|
||||
bare ``model_name`` because callbacks sometimes see whichever was
|
||||
configured first.
|
||||
"""
|
||||
provider_lower = (provider or "").lower()
|
||||
aliases: list[str] = []
|
||||
if base_model:
|
||||
aliases.append(base_model)
|
||||
if provider_lower and base_model:
|
||||
aliases.append(f"{provider_lower}/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(model_name)
|
||||
if provider_lower and model_name and model_name != base_model:
|
||||
aliases.append(f"{provider_lower}/{model_name}")
|
||||
# Azure deployments often surface as "azure/<name>"; normalise the
|
||||
# ``azure_openai`` provider slug to the LiteLLM-canonical ``azure``.
|
||||
if provider_lower == "azure_openai":
|
||||
if base_model:
|
||||
aliases.append(f"azure/{base_model}")
|
||||
if model_name and model_name != base_model:
|
||||
aliases.append(f"azure/{model_name}")
|
||||
return list(dict.fromkeys(a for a in aliases if a))
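# Illustrative expansion with assumed example values (an Azure deployment
# named "my-gpt-deploy" whose litellm_params.base_model is "gpt-4o"):
#   _alias_set_for_yaml("AZURE_OPENAI", "my-gpt-deploy", "gpt-4o")
#   -> ["gpt-4o", "azure_openai/gpt-4o", "my-gpt-deploy",
#       "azure_openai/my-gpt-deploy", "azure/gpt-4o", "azure/my-gpt-deploy"]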
|
||||
|
||||
|
||||
def _register(
|
||||
aliases: list[str],
|
||||
*,
|
||||
input_cost: float,
|
||||
output_cost: float,
|
||||
provider: str,
|
||||
mode: str = "chat",
|
||||
) -> int:
|
||||
"""Register a single pricing entry under every alias in ``aliases``.
|
||||
|
||||
Returns the count of aliases successfully registered.
|
||||
"""
|
||||
payload: dict[str, dict[str, Any]] = {}
|
||||
for alias in aliases:
|
||||
payload[alias] = {
|
||||
"input_cost_per_token": input_cost,
|
||||
"output_cost_per_token": output_cost,
|
||||
"litellm_provider": provider,
|
||||
"mode": mode,
|
||||
}
|
||||
if not payload:
|
||||
return 0
|
||||
try:
|
||||
litellm.register_model(payload)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[PricingRegistration] register_model failed for aliases=%s: %s",
|
||||
aliases,
|
||||
exc,
|
||||
)
|
||||
return 0
|
||||
return len(payload)
|
||||
|
||||
|
||||
def _register_chat_shape_configs(
|
||||
configs: list[dict],
|
||||
*,
|
||||
or_pricing: dict[str, dict[str, str]],
|
||||
label: str,
|
||||
) -> tuple[int, int, int, list[str]]:
|
||||
"""Common loop that registers per-token pricing for a list of "chat-shape"
|
||||
configs (chat or vision LLM — both use ``input_cost_per_token`` /
|
||||
``output_cost_per_token`` and the LiteLLM ``mode="chat"`` cost shape).
|
||||
|
||||
Returns ``(registered_models, registered_aliases, skipped, sample_keys)``.
|
||||
"""
|
||||
registered_models = 0
|
||||
registered_aliases = 0
|
||||
skipped_no_pricing = 0
|
||||
sample_keys: list[str] = []
|
||||
|
||||
for cfg in configs:
|
||||
provider = str(cfg.get("provider") or "").upper()
|
||||
model_name = str(cfg.get("model_name") or "").strip()
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = str(litellm_params.get("base_model") or model_name).strip()
|
||||
|
||||
if provider == "OPENROUTER":
|
||||
entry = or_pricing.get(model_name)
|
||||
if entry:
|
||||
input_cost = _safe_float(entry.get("prompt"))
|
||||
output_cost = _safe_float(entry.get("completion"))
|
||||
else:
|
||||
# Vision configs from ``_generate_vision_llm_configs``
|
||||
# carry their pricing inline because the OpenRouter
|
||||
# raw-pricing cache is keyed by chat-catalogue model_id;
|
||||
# vision flows pick up the inline values here.
|
||||
input_cost = _safe_float(cfg.get("input_cost_per_token"))
|
||||
output_cost = _safe_float(cfg.get("output_cost_per_token"))
|
||||
if input_cost == 0.0 and output_cost == 0.0:
|
||||
skipped_no_pricing += 1
|
||||
continue
|
||||
aliases = _alias_set_for_openrouter(model_name)
|
||||
count = _register(
|
||||
aliases,
|
||||
input_cost=input_cost,
|
||||
output_cost=output_cost,
|
||||
provider="openrouter",
|
||||
)
|
||||
if count > 0:
|
||||
registered_models += 1
|
||||
registered_aliases += count
|
||||
if len(sample_keys) < 6:
|
||||
sample_keys.extend(aliases[:2])
|
||||
continue
|
||||
|
||||
input_cost = _safe_float(
|
||||
cfg.get("input_cost_per_token")
|
||||
or litellm_params.get("input_cost_per_token")
|
||||
)
|
||||
output_cost = _safe_float(
|
||||
cfg.get("output_cost_per_token")
|
||||
or litellm_params.get("output_cost_per_token")
|
||||
)
|
||||
if input_cost == 0.0 and output_cost == 0.0:
|
||||
skipped_no_pricing += 1
|
||||
continue
|
||||
aliases = _alias_set_for_yaml(provider, model_name, base_model)
|
||||
provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower()
|
||||
count = _register(
|
||||
aliases,
|
||||
input_cost=input_cost,
|
||||
output_cost=output_cost,
|
||||
provider=provider_slug,
|
||||
)
|
||||
if count > 0:
|
||||
registered_models += 1
|
||||
registered_aliases += count
|
||||
if len(sample_keys) < 6:
|
||||
sample_keys.extend(aliases[:2])
|
||||
|
||||
logger.info(
|
||||
"[PricingRegistration:%s] registered pricing for %d models (%d aliases); "
|
||||
"%d configs had no pricing data; sample registered keys=%s",
|
||||
label,
|
||||
registered_models,
|
||||
registered_aliases,
|
||||
skipped_no_pricing,
|
||||
sample_keys,
|
||||
)
|
||||
return registered_models, registered_aliases, skipped_no_pricing, sample_keys
|
||||
|
||||
|
||||
def register_pricing_from_global_configs() -> None:
|
||||
"""Register pricing for every known LLM deployment with LiteLLM.
|
||||
|
||||
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
|
||||
so vision calls (during indexing) can resolve cost the same way chat
|
||||
calls do — namely:
|
||||
|
||||
1. ``OPENROUTER``: pulls the cached raw pricing from
|
||||
``OpenRouterIntegrationService`` (populated during its own
|
||||
startup fetch) and converts the per-token strings to floats. For
|
||||
vision configs that carry pricing inline (``input_cost_per_token`` /
|
||||
``output_cost_per_token`` set on the cfg itself) we fall back to
|
||||
those values when the OR cache misses the model.
|
||||
2. Anything else: looks for operator-declared
|
||||
``input_cost_per_token`` / ``output_cost_per_token`` on the YAML
|
||||
config block (top-level or nested under ``litellm_params``).
|
||||
|
||||
**Image generation is intentionally NOT registered here.** The cost
|
||||
shape for image-gen is per-image (``output_cost_per_image``), not
|
||||
per-token, and LiteLLM's ``register_model`` doesn't accept those
|
||||
keys via the chat-cost path. OpenRouter image-gen models populate
|
||||
``response_cost`` directly from their response header instead, and
|
||||
Azure-native image-gen models are already in LiteLLM's cost map.
|
||||
|
||||
Calls without a resolved pair of costs are skipped, not registered
|
||||
with zeros — operators who forget pricing get a "$0 debit" warning
|
||||
in ``TokenTrackingCallback`` rather than silently overwriting any
|
||||
pricing LiteLLM might know natively.
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
|
||||
vision_configs: list[dict] = list(
|
||||
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
|
||||
)
|
||||
if not chat_configs and not vision_configs:
|
||||
logger.info("[PricingRegistration] no global configs to register")
|
||||
return
|
||||
|
||||
or_pricing: dict[str, dict[str, str]] = {}
|
||||
try:
|
||||
from app.services.openrouter_integration_service import (
|
||||
OpenRouterIntegrationService,
|
||||
)
|
||||
|
||||
if OpenRouterIntegrationService.is_initialized():
|
||||
or_pricing = OpenRouterIntegrationService.get_instance().get_raw_pricing()
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"[PricingRegistration] OpenRouter pricing not available yet: %s", exc
|
||||
)
|
||||
|
||||
if chat_configs:
|
||||
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
|
||||
if vision_configs:
|
||||
_register_chat_shape_configs(
|
||||
vision_configs, or_pricing=or_pricing, label="vision"
|
||||
)
|
||||
106
surfsense_backend/app/services/provider_api_base.py
Normal file
|
|
@@ -0,0 +1,106 @@
|
|||
"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision.
|
||||
|
||||
LiteLLM falls back to the module-global ``litellm.api_base`` when an
|
||||
individual call doesn't pass one, which silently inherits provider-agnostic
|
||||
env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an
|
||||
explicit ``api_base``, an ``openrouter/<model>`` request can end up at an
|
||||
Azure endpoint and 404 with ``Resource not found`` (real reproducer:
|
||||
[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends
|
||||
``/chat/completions`` to whatever inherited base it gets, regardless of
|
||||
provider).
|
||||
|
||||
The chat router has had this defense for a while
|
||||
(``llm_router_service.py:466-478``). This module hoists the maps + cascade
|
||||
into a tiny standalone helper so vision and image-gen can share the same
|
||||
source of truth without an inter-service circular import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"fireworks_ai": "https://api.fireworks.ai/inference/v1",
|
||||
"together_ai": "https://api.together.xyz/v1",
|
||||
"anyscale": "https://api.endpoints.anyscale.com/v1",
|
||||
"cometapi": "https://api.cometapi.com/v1",
|
||||
"sambanova": "https://api.sambanova.ai/v1",
|
||||
}
|
||||
"""Default ``api_base`` per LiteLLM provider prefix (lowercase).
|
||||
|
||||
Only providers with a well-known, stable public base URL are listed —
|
||||
self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai,
|
||||
huggingface, databricks, cloudflare, replicate) are intentionally omitted
|
||||
so their existing config-driven behaviour is preserved."""
|
||||
|
||||
|
||||
PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"DEEPSEEK": "https://api.deepseek.com/v1",
|
||||
"ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"MOONSHOT": "https://api.moonshot.ai/v1",
|
||||
"ZHIPU": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"MINIMAX": "https://api.minimax.io/v1",
|
||||
}
|
||||
"""Canonical provider key (uppercase) → base URL.
|
||||
|
||||
Used when the LiteLLM provider prefix is the generic ``openai`` shim but the
|
||||
config's ``provider`` field tells us which API it actually is (DeepSeek,
|
||||
Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each
|
||||
has its own base URL)."""
|
||||
|
||||
|
||||
def resolve_api_base(
|
||||
*,
|
||||
provider: str | None,
|
||||
provider_prefix: str | None,
|
||||
config_api_base: str | None,
|
||||
) -> str | None:
|
||||
"""Resolve a non-Azure-leaking ``api_base`` for a deployment.
|
||||
|
||||
Cascade (first non-empty wins):
|
||||
1. The config's own ``api_base`` (whitespace-only treated as missing).
|
||||
2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``.
|
||||
3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``.
|
||||
4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM
|
||||
provider integration apply its own default (e.g. AzureOpenAI's
|
||||
deployment-derived URL, custom provider's per-deployment URL).
|
||||
|
||||
Args:
|
||||
provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``,
|
||||
``"DEEPSEEK"``). Case-insensitive.
|
||||
provider_prefix: The LiteLLM model-string prefix the same call
|
||||
site builds for the model id (e.g. ``"openrouter"``,
|
||||
``"groq"``). Case-insensitive.
|
||||
config_api_base: ``api_base`` from the global YAML / DB row /
|
||||
OpenRouter dynamic config. Empty / whitespace-only means
|
||||
"missing" — the resolver still applies the cascade.
|
||||
|
||||
Returns:
|
||||
A URL string, or ``None`` if no default applies for this provider.
|
||||
"""
|
||||
if config_api_base and config_api_base.strip():
|
||||
return config_api_base
|
||||
|
||||
if provider:
|
||||
key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper())
|
||||
if key_default:
|
||||
return key_default
|
||||
|
||||
if provider_prefix:
|
||||
prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower())
|
||||
if prefix_default:
|
||||
return prefix_default
|
||||
|
||||
return None
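# Illustrative cascade walk-through (assumed inputs, not from a real config):
#   resolve_api_base(provider="DEEPSEEK", provider_prefix="openai",
#                    config_api_base="   ")
#   -> "https://api.deepseek.com/v1"      # step 2: provider-key default
#   resolve_api_base(provider="OPENROUTER", provider_prefix="openrouter",
#                    config_api_base=None)
#   -> "https://openrouter.ai/api/v1"     # step 3: prefix default
#   resolve_api_base(provider="OLLAMA", provider_prefix="ollama",
#                    config_api_base=None)
#   -> None                               # step 4: caller leaves api_base unset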
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROVIDER_DEFAULT_API_BASE",
|
||||
"PROVIDER_KEY_DEFAULT_API_BASE",
|
||||
"resolve_api_base",
|
||||
]
|
||||
280
surfsense_backend/app/services/provider_capabilities.py
Normal file
|
|
@@ -0,0 +1,280 @@
|
|||
"""Capability resolution shared by chat / image / vision call sites.
|
||||
|
||||
Why this exists
|
||||
---------------
|
||||
The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
|
||||
single, authoritative answer to one question: *can this chat config accept
|
||||
``image_url`` content blocks?* Without it, the new-chat selector can't badge
|
||||
incompatible models and the streaming task can't fail fast with a friendly
|
||||
error before sending an image to a text-only provider.
|
||||
|
||||
Two functions, two intents:
|
||||
|
||||
- :func:`derive_supports_image_input` — best-effort *True* for catalog and
|
||||
UI surfacing. Default-allow: an unknown / unmapped model is treated as
|
||||
capable so we never lock the user out of a freshly added or
|
||||
third-party-hosted vision model.
|
||||
|
||||
- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming
|
||||
task's safety net. Returns True only when LiteLLM's model map *explicitly*
|
||||
sets ``supports_vision=False`` (or its bare-name variant does). Anything
|
||||
else — missing key, lookup exception, ``supports_vision=True`` — returns
|
||||
False so the request flows through to the provider.
|
||||
|
||||
Implementation rule: only public LiteLLM symbols
|
||||
------------------------------------------------
|
||||
``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
|
||||
typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
|
||||
across releases. The private ``_is_explicitly_disabled_factory`` and
|
||||
``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
|
||||
can't silently break us.
|
||||
|
||||
Why the previous round's strict YAML opt-in flag failed
|
||||
-------------------------------------------------------
|
||||
``supports_image_input: false`` was the YAML loader's setdefault. Operators
|
||||
maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
|
||||
YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to
|
||||
False and the streaming gate rejected every image turn. Sourcing capability
|
||||
from LiteLLM's authoritative model map (which already says
|
||||
``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Provider-name → LiteLLM model-prefix map.
|
||||
#
|
||||
# Owned here because ``app.services.provider_capabilities`` is the
|
||||
# only edge that's safe to call from ``app.config``'s YAML loader at
|
||||
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
|
||||
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||
# map there directly would re-introduce the
|
||||
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
|
||||
# app.config`` cycle that prompted the move.
|
||||
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"COMETAPI": "cometapi",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
def _candidate_model_strings(
|
||||
*,
|
||||
provider: str | None,
|
||||
model_name: str | None,
|
||||
base_model: str | None,
|
||||
custom_provider: str | None,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
|
||||
|
||||
LiteLLM's capability lookup is keyed by ``model`` + (optional)
|
||||
``custom_llm_provider``. Different config sources give us different
|
||||
levels of detail, so we try the most-specific keys first and fall back
|
||||
to bare model names so unannotated entries (e.g. an Azure deployment
|
||||
pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
|
||||
map. Order matters — the first lookup that returns a definitive answer
|
||||
wins for both helpers.
|
||||
"""
|
||||
candidates: list[tuple[str, str | None]] = []
|
||||
seen: set[tuple[str, str | None]] = set()
|
||||
|
||||
def _add(model: str | None, llm_provider: str | None) -> None:
|
||||
if not model:
|
||||
return
|
||||
key = (model, llm_provider)
|
||||
if key in seen:
|
||||
return
|
||||
seen.add(key)
|
||||
candidates.append(key)
|
||||
|
||||
provider_prefix: str | None = None
|
||||
if provider:
|
||||
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
|
||||
if custom_provider:
|
||||
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
|
||||
provider_prefix = custom_provider
|
||||
|
||||
primary_model = base_model or model_name
|
||||
bare_model = model_name
|
||||
|
||||
# Most-specific first: provider-prefixed identifier with explicit
|
||||
# custom_llm_provider so LiteLLM won't have to guess the provider via
|
||||
# ``get_llm_provider``.
|
||||
if primary_model and provider_prefix:
|
||||
# e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
|
||||
if "/" in primary_model:
|
||||
_add(primary_model, provider_prefix)
|
||||
else:
|
||||
_add(f"{provider_prefix}/{primary_model}", provider_prefix)
|
||||
|
||||
# Bare base_model (or model_name) with provider hint — handles entries
|
||||
# the upstream map keys without a provider prefix (most ``gpt-*`` and
|
||||
# ``claude-*`` entries do this).
|
||||
if primary_model:
|
||||
_add(primary_model, provider_prefix)
|
||||
|
||||
# Fallback to model_name when base_model differs (e.g. an Azure
|
||||
# deployment whose model_name is the deployment id but base_model is the
|
||||
# canonical OpenAI sku).
|
||||
if bare_model and bare_model != primary_model:
|
||||
if provider_prefix and "/" not in bare_model:
|
||||
_add(f"{provider_prefix}/{bare_model}", provider_prefix)
|
||||
_add(bare_model, provider_prefix)
|
||||
_add(bare_model, None)
|
||||
|
||||
return candidates
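# Illustrative candidate order for an assumed Azure config whose deployment
# name differs from its canonical sku ("my-gpt-deploy" with base_model "gpt-4o"):
#   _candidate_model_strings(provider="AZURE_OPENAI", model_name="my-gpt-deploy",
#                            base_model="gpt-4o", custom_provider=None)
#   -> [("azure/gpt-4o", "azure"), ("gpt-4o", "azure"),
#       ("azure/my-gpt-deploy", "azure"), ("my-gpt-deploy", "azure"),
#       ("my-gpt-deploy", None)]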
|
||||
|
||||
|
||||
def derive_supports_image_input(
|
||||
*,
|
||||
provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
base_model: str | None = None,
|
||||
custom_provider: str | None = None,
|
||||
openrouter_input_modalities: Iterable[str] | None = None,
|
||||
) -> bool:
|
||||
"""Best-effort capability flag for the new-chat selector and catalog.
|
||||
|
||||
Resolution order (first definitive answer wins):
|
||||
|
||||
1. ``openrouter_input_modalities`` (when provided as a non-empty
|
||||
iterable). OpenRouter exposes ``architecture.input_modalities`` per
|
||||
model and that's the authoritative source for OR dynamic configs.
|
||||
2. ``litellm.supports_vision`` against each candidate identifier from
|
||||
:func:`_candidate_model_strings`. Returns True as soon as any
|
||||
candidate confirms vision support.
|
||||
3. Default ``True`` — the conservative-allow stance. An unknown /
|
||||
newly-added / third-party-hosted model is *not* pre-judged. The
|
||||
streaming safety net (:func:`is_known_text_only_chat_model`) is the
|
||||
only place a False ever blocks; everywhere else, a False here would
|
||||
just hide a usable model from the user.
|
||||
|
||||
Returns:
|
||||
True if the model can plausibly accept image input, False only when
|
||||
OpenRouter explicitly says it can't.
|
||||
"""
|
||||
if openrouter_input_modalities is not None:
|
||||
modalities = list(openrouter_input_modalities)
|
||||
if modalities:
|
||||
return "image" in modalities
|
||||
# Empty list explicitly published by OR — treat as "no image".
|
||||
return False
|
||||
|
||||
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
):
|
||||
try:
|
||||
if litellm.supports_vision(
|
||||
model=model_string, custom_llm_provider=custom_llm_provider
|
||||
):
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"litellm.supports_vision raised for model=%s provider=%s: %s",
|
||||
model_string,
|
||||
custom_llm_provider,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
# Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
|
||||
return True
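# Illustrative calls (assumed inputs): OpenRouter modalities, when present,
# short-circuit the LiteLLM lookup entirely.
#   derive_supports_image_input(openrouter_input_modalities=["text", "image"]) -> True
#   derive_supports_image_input(openrouter_input_modalities=["text"])          -> False
#   derive_supports_image_input(provider="CUSTOM", model_name="my-local-model")
#   -> True (default-allow: no candidate confirms or denies vision support)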
|
||||
|
||||
|
||||
def is_known_text_only_chat_model(
|
||||
*,
|
||||
provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
base_model: str | None = None,
|
||||
custom_provider: str | None = None,
|
||||
) -> bool:
|
||||
"""Strict opt-out probe for the streaming-task safety net.
|
||||
|
||||
Returns True only when LiteLLM's model map *explicitly* sets
|
||||
``supports_vision=False`` for at least one candidate identifier. Missing
|
||||
key, lookup exception, or ``supports_vision=True`` all return False so
|
||||
the streaming task lets the request through. This is the inverse-default
|
||||
of :func:`derive_supports_image_input`.
|
||||
|
||||
Why two functions
|
||||
-----------------
|
||||
The selector wants "show me everything that's plausibly capable" —
|
||||
default-allow. The safety net wants "block only when I'm certain it
|
||||
can't" — default-pass. Mixing the two intents in a single function
|
||||
leads to the regression we're fixing here.
|
||||
"""
|
||||
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
):
|
||||
try:
|
||||
info = litellm.get_model_info(
|
||||
model=model_string, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"litellm.get_model_info raised for model=%s provider=%s: %s",
|
||||
model_string,
|
||||
custom_llm_provider,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
# ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
|
||||
# may be missing, None, True, or False. We only fire on explicit
|
||||
# False — None / missing / True all mean "don't block".
|
||||
try:
|
||||
value = info.get("supports_vision") # type: ignore[union-attr]
|
||||
except AttributeError:
|
||||
value = None
|
||||
if value is False:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
__all__ = [
|
||||
"derive_supports_image_input",
|
||||
"is_known_text_only_chat_model",
|
||||
]
|
||||
380
surfsense_backend/app/services/quality_score.py
Normal file
|
|
@@ -0,0 +1,380 @@
|
|||
"""Pure-function quality scoring for Auto (Fastest) model selection.
|
||||
|
||||
This module is free of any service / request-path imports. All
|
||||
numbers are computed once during the OpenRouter refresh tick (or YAML load)
|
||||
and cached on the cfg dict, so the chat hot path only does a precomputed
|
||||
sort and a SHA256 pick.
|
||||
|
||||
Score components (0-100 scale, higher is better):
|
||||
|
||||
* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload
|
||||
(provider prestige + ``created`` recency + pricing band + context window
|
||||
+ capabilities + narrow tiny/legacy slug penalty).
|
||||
* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus
|
||||
an operator-trust bonus (the operator deliberately picked this model).
|
||||
* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints``
|
||||
responses; returns ``(gated, score_or_none)``.
|
||||
|
||||
The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in
|
||||
:mod:`app.services.openrouter_integration_service` because that's the only
|
||||
caller that sees both halves.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tunables (constants, not flags)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Top-K size for deterministic spread inside the locked tier.
|
||||
_QUALITY_TOP_K: int = 5
|
||||
|
||||
# Hard health gate: any cfg whose best non-null uptime is below this %
|
||||
# is excluded from Auto-mode selection entirely.
|
||||
_HEALTH_GATE_UPTIME_PCT: float = 90.0
|
||||
|
||||
# Health/static blend weight when a cfg has fresh /endpoints data.
|
||||
_HEALTH_BLEND_WEIGHT: float = 0.5
|
||||
|
||||
# Static bonus applied to YAML cfgs because the operator hand-picked them.
|
||||
_OPERATOR_TRUST_BONUS: int = 20
|
||||
|
||||
# /endpoints fan-out is bounded per refresh tick.
|
||||
_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50
|
||||
_HEALTH_ENRICH_TOP_N_FREE: int = 30
|
||||
_HEALTH_ENRICH_CONCURRENCY: int = 15
|
||||
_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0
|
||||
|
||||
# If at least this fraction of /endpoints fetches fail in a refresh cycle,
|
||||
# fall back to the previous cycle's last-good cache instead of writing
|
||||
# partial / stale health values.
|
||||
_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25
|
||||
|
||||
# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise
|
||||
# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with
|
||||
# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and
|
||||
# blanket-penalising them suppresses high-quality picks.
|
||||
_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = (
|
||||
"-1b-",
|
||||
"-1.2b-",
|
||||
"-1.5b-",
|
||||
"-2b-",
|
||||
"-3b-",
|
||||
"gemma-3n",
|
||||
"lfm-",
|
||||
"-base",
|
||||
"-distill",
|
||||
":nitro",
|
||||
"-preview",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider prestige tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# OpenRouter-side provider slug (the prefix before ``/`` in the model id).
|
||||
# Tiers are coarse: frontier labs > strong open / fast-moving labs >
|
||||
# specialist labs > everything else.
|
||||
PROVIDER_PRESTIGE_OR: dict[str, int] = {
|
||||
# Frontier labs
|
||||
"openai": 50,
|
||||
"anthropic": 50,
|
||||
"google": 50,
|
||||
"x-ai": 50,
|
||||
# Strong open / fast-moving labs
|
||||
"deepseek": 38,
|
||||
"qwen": 38,
|
||||
"meta-llama": 38,
|
||||
"mistralai": 38,
|
||||
"cohere": 38,
|
||||
"nvidia": 38,
|
||||
"alibaba": 38,
|
||||
# Specialist / regional / strong second-tier
|
||||
"microsoft": 28,
|
||||
"01-ai": 28,
|
||||
"minimax": 28,
|
||||
"moonshot": 28,
|
||||
"z-ai": 28,
|
||||
"nousresearch": 28,
|
||||
"ai21": 28,
|
||||
"perplexity": 28,
|
||||
# Smaller / niche providers
|
||||
"liquid": 18,
|
||||
"cognitivecomputations": 18,
|
||||
"venice": 18,
|
||||
"inflection": 18,
|
||||
}
|
||||
|
||||
# YAML provider field (the upstream API shape the operator selected).
|
||||
PROVIDER_PRESTIGE_YAML: dict[str, int] = {
|
||||
"AZURE_OPENAI": 50,
|
||||
"OPENAI": 50,
|
||||
"ANTHROPIC": 50,
|
||||
"GOOGLE": 50,
|
||||
"VERTEX_AI": 50,
|
||||
"GEMINI": 50,
|
||||
"XAI": 50,
|
||||
"MISTRAL": 38,
|
||||
"DEEPSEEK": 38,
|
||||
"COHERE": 38,
|
||||
"GROQ": 30,
|
||||
"TOGETHER_AI": 28,
|
||||
"FIREWORKS_AI": 28,
|
||||
"PERPLEXITY": 28,
|
||||
"MINIMAX": 28,
|
||||
"BEDROCK": 28,
|
||||
"OPENROUTER": 25,
|
||||
"OLLAMA": 12,
|
||||
"CUSTOM": 12,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure scoring helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Calibrated against the live /api/v1/models bulk dump. Frontier models
|
||||
# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5,
|
||||
# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band;
|
||||
# anything older trails off.
|
||||
_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = (
|
||||
(60, 20),
|
||||
(180, 16),
|
||||
(365, 12),
|
||||
(540, 9),
|
||||
(730, 6),
|
||||
(1095, 3),
|
||||
)
|
||||
|
||||
|
||||
def created_recency_signal(created_ts: int | None, now_ts: int) -> int:
|
||||
"""Return 0-20 based on how recently the model was published.
|
||||
|
||||
Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for
|
||||
YAML cfgs). Models without a usable timestamp get 0 (we don't penalise,
|
||||
we just don't reward).
|
||||
"""
|
||||
if created_ts is None or created_ts <= 0 or now_ts <= 0:
|
||||
return 0
|
||||
age_days = max(0, (now_ts - int(created_ts)) // 86_400)
|
||||
for cutoff, score in _RECENCY_BANDS_DAYS:
|
||||
if age_days <= cutoff:
|
||||
return score
|
||||
return 0
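# Illustrative values (assumed timestamps): a model published 100 days ago
# falls in the (180, 16) band -> 16; one published over three years ago -> 0.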
|
||||
|
||||
|
||||
def pricing_band(
|
||||
prompt: str | float | int | None,
|
||||
completion: str | float | int | None,
|
||||
) -> int:
|
||||
"""Return 0-15 based on combined prompt+completion cost per 1M tokens.
|
||||
|
||||
Higher-priced models tend to be the larger / more capable ones. A free
|
||||
model returns 0 (we use other signals to rank free-vs-free instead).
|
||||
Uncoercible inputs are treated as 0 rather than raising.
|
||||
"""
|
||||
|
||||
def _to_float(value) -> float:
|
||||
if value is None:
|
||||
return 0.0
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
|
||||
p = _to_float(prompt)
|
||||
c = _to_float(completion)
|
||||
total_per_million = (p + c) * 1_000_000
|
||||
|
||||
if total_per_million >= 20.0:
|
||||
return 15
|
||||
if total_per_million >= 5.0:
|
||||
return 12
|
||||
if total_per_million >= 1.0:
|
||||
return 9
|
||||
if total_per_million >= 0.3:
|
||||
return 6
|
||||
if total_per_million >= 0.05:
|
||||
return 4
|
||||
if total_per_million > 0.0:
|
||||
return 2
|
||||
return 0
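# Illustrative band lookup (assumed OpenRouter pricing strings):
#   pricing_band("0.000003", "0.000015")   # $3 + $15 per 1M tokens = $18
#   -> 12                                  # falls in the >= 5.0 band
#   pricing_band("0", "0")                 # free model
#   -> 0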


def context_signal(ctx: int | None) -> int:
    """Return 0-10 based on the model's context window."""
    if not ctx or ctx <= 0:
        return 0
    if ctx >= 1_000_000:
        return 10
    if ctx >= 400_000:
        return 8
    if ctx >= 200_000:
        return 6
    if ctx >= 128_000:
        return 4
    if ctx >= 100_000:
        return 2
    return 0


def capabilities_signal(supported_parameters: list[str] | None) -> int:
    """Return 0-5 for capabilities that matter for our agent flows."""
    if not supported_parameters:
        return 0
    params = set(supported_parameters)
    score = 0
    if "tools" in params:
        score += 2
    if "structured_outputs" in params or "response_format" in params:
        score += 2
    if "reasoning" in params or "include_reasoning" in params:
        score += 1
    return min(score, 5)


def slug_penalty(model_id: str) -> int:
    """Return a non-positive number; matches the narrow tiny/legacy patterns."""
    if not model_id:
        return 0
    needle = model_id.lower()
    for pattern in _TINY_LEGACY_PENALTY_PATTERNS:
        if pattern in needle:
            return -10
    return 0


def _provider_prestige_or(model_id: str) -> int:
    if "/" not in model_id:
        return 0
    slug = model_id.split("/", 1)[0].lower()
    return PROVIDER_PRESTIGE_OR.get(slug, 15)


def static_score_or(or_model: dict, *, now_ts: int) -> int:
    """Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale."""
    model_id = str(or_model.get("id", ""))
    pricing = or_model.get("pricing") or {}

    score = (
        _provider_prestige_or(model_id)
        + created_recency_signal(or_model.get("created"), now_ts)
        + pricing_band(pricing.get("prompt"), pricing.get("completion"))
        + context_signal(or_model.get("context_length"))
        + capabilities_signal(or_model.get("supported_parameters"))
        + slug_penalty(model_id)
    )
    return max(0, min(100, int(score)))
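
# Illustrative only: a minimal shape of the catalogue entry this function reads
# (field names taken from the code above; values are hypothetical, so no exact
# total is claimed here).
#
#     entry = {
#         "id": "venice/some-model",
#         "created": 1_730_000_000,
#         "pricing": {"prompt": "0.000003", "completion": "0.000015"},
#         "context_length": 200_000,
#         "supported_parameters": ["tools", "response_format"],
#     }
#     static_score_or(entry, now_ts=1_760_000_000)  # -> int in [0, 100]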


def static_score_yaml(cfg: dict) -> int:
    """Score a YAML-curated cfg on a 0-100 scale.

    Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately
    listed this model. Pricing / context fall through to lazy ``litellm``
    lookups; failures are silent (we just lose those sub-points).
    """
    provider = str(cfg.get("provider", "")).upper()
    base = PROVIDER_PRESTIGE_YAML.get(provider, 15)

    model_name = cfg.get("model_name") or ""
    litellm_params = cfg.get("litellm_params") or {}
    lookup_name = (
        litellm_params.get("base_model") or litellm_params.get("model") or model_name
    )

    ctx = 0
    p_cost: float = 0.0
    c_cost: float = 0.0
    try:
        from litellm import get_model_info  # lazy: avoid cold-import cost

        info = get_model_info(lookup_name) or {}
        ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0)
        p_cost = float(info.get("input_cost_per_token") or 0.0)
        c_cost = float(info.get("output_cost_per_token") or 0.0)
    except Exception:
        # Unknown to litellm — that's fine for prestige+operator-bonus weighting.
        pass

    score = (
        base
        + _OPERATOR_TRUST_BONUS
        + pricing_band(p_cost, c_cost)
        + context_signal(ctx)
        + slug_penalty(str(model_name))
    )
    return max(0, min(100, int(score)))


# ---------------------------------------------------------------------------
# Health aggregation
# ---------------------------------------------------------------------------


def _coerce_pct(value) -> float | None:
    try:
        if value is None:
            return None
        f = float(value)
    except (TypeError, ValueError):
        return None
    if f < 0:
        return None
    # OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it
    # as a 0-100 percentage. Normalise.
    return f * 100.0 if f <= 1.0 else f


def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]:
    """Pick the best (highest) non-null uptime across all endpoints.

    Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` >
    ``uptime_last_5m``. Returns ``(uptime_pct, window_used)``.
    """
    for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"):
        values = [_coerce_pct(ep.get(window)) for ep in endpoints]
        values = [v for v in values if v is not None]
        if values:
            return max(values), window
    return None, None


def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]:
    """Aggregate a model's per-endpoint health into ``(gated, score_or_none)``.

    Hard gate (returns ``(True, None)``):
      * ``endpoints`` empty,
      * no endpoint reports ``status == 0`` (OK), or
      * best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``.

    On a pass, returns a 0-100 health score blending uptime, status, and a
    freshness-weighted recent uptime sample.
    """
    if not endpoints:
        return True, None

    any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints)
    if not any_ok:
        return True, None

    best_uptime, _ = _best_uptime(endpoints)
    if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT:
        return True, None

    # Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing.
    freshness = None
    for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"):
        values = [_coerce_pct(ep.get(window)) for ep in endpoints]
        values = [v for v in values if v is not None]
        if values:
            freshness = max(values)
            break

    uptime_term = best_uptime
    status_term = 100.0 if any_ok else 0.0
    freshness_term = freshness if freshness is not None else best_uptime

    score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term
    return False, max(0.0, min(100.0, score))
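
# Illustrative only (hypothetical endpoint dicts; assumes _HEALTH_GATE_UPTIME_PCT
# is below 99): a single OK endpoint with 99% uptime over 30m and 97% over 5m
# passes the gate and blends to roughly 0.50 * 99 + 0.30 * 100 + 0.20 * 97 ≈ 98.9.
#
#     gated, health = aggregate_health(
#         [{"status": 0, "uptime_last_30m": 0.99, "uptime_last_5m": 0.97}]
#     )
#     # gated is False, health is about 98.9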

105
surfsense_backend/app/services/quota_checked_vision_llm.py
Normal file
@ -0,0 +1,105 @@
"""
|
||||
Vision LLM proxy that enforces premium credit quota on every ``ainvoke``.
|
||||
|
||||
Used by :func:`app.services.llm_service.get_vision_llm` so callers in the
|
||||
indexing pipeline (file processors, connector indexers, etl pipeline) can
|
||||
keep invoking the LLM exactly the way they do today — ``await llm.ainvoke(...)``
|
||||
— without threading ``user_id`` through every parser. The wrapper looks like
|
||||
a chat model from the outside; on the inside it routes each call through
|
||||
``billable_call`` so the user's premium credit pool is reserved → finalized
|
||||
or released, and a ``TokenUsage`` audit row is written.
|
||||
|
||||
Free configs are returned unwrapped from ``get_vision_llm`` (they do not
|
||||
need quota enforcement) so this class only ever wraps premium configs.
|
||||
|
||||
Why a wrapper instead of plumbing ``user_id`` through every caller:
|
||||
|
||||
* The indexer ecosystem has 8+ entry points (Google Drive, OneDrive,
|
||||
Dropbox, local-folder, file-processor, ETL pipeline) each calling
|
||||
``parse_with_vision_llm(...)``. Adding a ``user_id`` argument to each is
|
||||
invasive, error-prone, and easy for a future indexer to forget.
|
||||
* Per the design (issue M), we always debit the *search-space owner*, not
|
||||
the triggering user, so ``user_id`` is fully derivable from the search
|
||||
space the caller is already operating on. The wrapper captures it once
|
||||
at construction time.
|
||||
* ``langchain_litellm.ChatLiteLLM`` has no public hook for "before each
|
||||
call run this coroutine"; subclassing isn't safe across versions because
|
||||
it derives from ``BaseChatModel`` which expects specific Pydantic shapes.
|
||||
Composition via attribute proxying (``__getattr__``) is robust to
|
||||
upstream changes — every method other than ``ainvoke`` falls through to
|
||||
the inner LLM unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.billable_calls import QuotaInsufficientError, billable_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuotaCheckedVisionLLM:
|
||||
"""Composition wrapper around a langchain chat model that enforces
|
||||
premium credit quota on every ``ainvoke``.
|
||||
|
||||
Anything other than ``ainvoke`` is forwarded to the inner model so
|
||||
``invoke`` (sync), ``astream``, ``with_structured_output``, etc. all
|
||||
still work — they simply bypass quota enforcement, which is fine
|
||||
because the indexing pipeline only ever calls ``ainvoke`` today.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner_llm: Any,
|
||||
*,
|
||||
user_id: UUID,
|
||||
search_space_id: int,
|
||||
billing_tier: str,
|
||||
base_model: str,
|
||||
quota_reserve_tokens: int | None,
|
||||
usage_type: str = "vision_extraction",
|
||||
) -> None:
|
||||
self._inner = inner_llm
|
||||
self._user_id = user_id
|
||||
self._search_space_id = search_space_id
|
||||
self._billing_tier = billing_tier
|
||||
self._base_model = base_model
|
||||
self._quota_reserve_tokens = quota_reserve_tokens
|
||||
self._usage_type = usage_type
|
||||
|
||||
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Proxied async invoke that runs the underlying call inside
|
||||
``billable_call``.
|
||||
|
||||
Raises:
|
||||
QuotaInsufficientError: when the user has exhausted their
|
||||
premium credit pool. Caller (``etl_pipeline_service._extract_image``)
|
||||
catches this and falls back to the document parser.
|
||||
"""
|
||||
async with billable_call(
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
billing_tier=self._billing_tier,
|
||||
base_model=self._base_model,
|
||||
quota_reserve_tokens=self._quota_reserve_tokens,
|
||||
usage_type=self._usage_type,
|
||||
call_details={"model": self._base_model},
|
||||
):
|
||||
return await self._inner.ainvoke(input, *args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Forward everything else (``invoke``, ``astream``, ``bind``,
|
||||
``with_structured_output``, …) to the inner model.
|
||||
|
||||
``__getattr__`` is only consulted when the attribute is *not*
|
||||
already found on the proxy, which is exactly the contract we
|
||||
want — methods we override stay on the proxy, the rest fall
|
||||
through.
|
||||
"""
|
||||
return getattr(self._inner, name)
|
||||
|
||||
|
||||
__all__ = ["QuotaCheckedVisionLLM", "QuotaInsufficientError"]
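
# A minimal usage sketch (hypothetical wiring; the real construction happens
# inside app.services.llm_service.get_vision_llm, and only for premium configs):
#
#     wrapped = QuotaCheckedVisionLLM(
#         inner_llm=chat_litellm_model,          # any langchain chat model
#         user_id=owner_id,                      # search-space owner, not the caller
#         search_space_id=search_space.id,
#         billing_tier="premium",
#         base_model="gpt-4o",                   # hypothetical model name
#         quota_reserve_tokens=None,
#     )
#     result = await wrapped.ainvoke(messages)   # reserved -> finalized/released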

@ -22,6 +22,71 @@ from app.config import config
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Per-call reservation estimator (USD micro-units)
# ---------------------------------------------------------------------------

# Minimum reserve in micros so a user with $0.0001 left can still make a tiny
# request, and so models without registered pricing reserve at least
# something while the call runs (debited 0 at finalize anyway when their
# cost can't be resolved).
_QUOTA_MIN_RESERVE_MICROS = 100


def estimate_call_reserve_micros(
    *,
    base_model: str,
    quota_reserve_tokens: int | None,
) -> int:
    """Return the number of micro-USD to reserve for one premium call.

    Computes a worst-case upper bound from LiteLLM's per-token pricing
    table:

        reserve_usd ≈ reserve_tokens x (input_cost + output_cost)

    so the math scales with model cost — Claude Opus + 4K reserve_tokens
    naturally reserves ≈ $0.36, while a cheap model reserves only a few
    cents. Clamped to ``[_QUOTA_MIN_RESERVE_MICROS, QUOTA_MAX_RESERVE_MICROS]``
    so a misconfigured "$1000/M" model can't lock the whole balance on
    one call.

    If ``litellm.get_model_info`` raises (model unknown) we fall back to
    the floor — 100 micros / $0.0001 — which is enough to gate a sane
    request without over-reserving for a model whose pricing the
    operator hasn't declared yet.
    """
    reserve_tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
    if reserve_tokens <= 0:
        reserve_tokens = config.QUOTA_MAX_RESERVE_PER_CALL

    try:
        from litellm import get_model_info

        info = get_model_info(base_model) if base_model else {}
        input_cost = float(info.get("input_cost_per_token") or 0.0)
        output_cost = float(info.get("output_cost_per_token") or 0.0)
    except Exception as exc:
        logger.debug(
            "[quota_reserve] cost lookup failed for base_model=%s: %s",
            base_model,
            exc,
        )
        input_cost = 0.0
        output_cost = 0.0

    if input_cost == 0.0 and output_cost == 0.0:
        return _QUOTA_MIN_RESERVE_MICROS

    reserve_usd = reserve_tokens * (input_cost + output_cost)
    reserve_micros = round(reserve_usd * 1_000_000)
    if reserve_micros < _QUOTA_MIN_RESERVE_MICROS:
        reserve_micros = _QUOTA_MIN_RESERVE_MICROS
    if reserve_micros > config.QUOTA_MAX_RESERVE_MICROS:
        reserve_micros = config.QUOTA_MAX_RESERVE_MICROS
    return reserve_micros
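
# Worked example of the docstring's Claude Opus figure (rates as published in
# LiteLLM's pricing table, roughly $15/M input and $75/M output, i.e. 9e-05 USD
# per combined token): a 4_000-token reserve gives
# 4_000 * 9e-05 = 0.36 USD = 360_000 micros, before the clamp to
# config.QUOTA_MAX_RESERVE_MICROS.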


class QuotaScope(StrEnum):
    ANONYMOUS = "anonymous"
    PREMIUM = "premium"

@ -444,8 +509,16 @@ class TokenQuotaService:
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        reserve_tokens: int,
        reserve_micros: int,
    ) -> QuotaResult:
        """Reserve ``reserve_micros`` (USD micro-units) from the user's
        premium credit balance.

        ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are
        all in micro-USD on this code path; callers (chat stream,
        token-status route, FE display) convert to dollars by dividing
        by 1_000_000.
        """
        from app.db import User

        user = (

@ -465,11 +538,11 @@ class TokenQuotaService:
                limit=0,
            )

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        limit = user.premium_credit_micros_limit
        used = user.premium_credit_micros_used
        reserved = user.premium_credit_micros_reserved

        effective = used + reserved + reserve_tokens
        effective = used + reserved + reserve_micros
        if effective > limit:
            remaining = max(0, limit - used - reserved)
            await db_session.rollback()

@ -482,10 +555,10 @@ class TokenQuotaService:
                remaining=remaining,
            )

        user.premium_tokens_reserved = reserved + reserve_tokens
        user.premium_credit_micros_reserved = reserved + reserve_micros
        await db_session.commit()

        new_reserved = reserved + reserve_tokens
        new_reserved = reserved + reserve_micros
        remaining = max(0, limit - used - new_reserved)
        warning_threshold = int(limit * 0.8)

@ -510,9 +583,12 @@ class TokenQuotaService:
        db_session: AsyncSession,
        user_id: Any,
        request_id: str,
        actual_tokens: int,
        reserved_tokens: int,
        actual_micros: int,
        reserved_micros: int,
    ) -> QuotaResult:
        """Settle the reservation: release ``reserved_micros`` and debit
        ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD).
        """
        from app.db import User

        user = (

@ -529,16 +605,18 @@ class TokenQuotaService:
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )

        user.premium_tokens_reserved = max(
            0, user.premium_tokens_reserved - reserved_tokens
        user.premium_credit_micros_reserved = max(
            0, user.premium_credit_micros_reserved - reserved_micros
        )
        user.premium_credit_micros_used = (
            user.premium_credit_micros_used + actual_micros
        )
        user.premium_tokens_used = user.premium_tokens_used + actual_tokens

        await db_session.commit()

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        limit = user.premium_credit_micros_limit
        used = user.premium_credit_micros_used
        reserved = user.premium_credit_micros_reserved
        remaining = max(0, limit - used - reserved)

        warning_threshold = int(limit * 0.8)

@ -562,8 +640,13 @@ class TokenQuotaService:
    async def premium_release(
        db_session: AsyncSession,
        user_id: Any,
        reserved_tokens: int,
        reserved_micros: int,
    ) -> None:
        """Release ``reserved_micros`` previously held by ``premium_reserve``.

        Used when a request fails before finalize (so the reservation
        doesn't leak credit).
        """
        from app.db import User

        user = (

@ -576,8 +659,8 @@ class TokenQuotaService:
            .scalar_one_or_none()
        )
        if user is not None:
            user.premium_tokens_reserved = max(
                0, user.premium_tokens_reserved - reserved_tokens
            user.premium_credit_micros_reserved = max(
                0, user.premium_credit_micros_reserved - reserved_micros
            )
            await db_session.commit()

@ -598,9 +681,9 @@ class TokenQuotaService:
                allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0
            )

        limit = user.premium_tokens_limit
        used = user.premium_tokens_used
        reserved = user.premium_tokens_reserved
        limit = user.premium_credit_micros_limit
        used = user.premium_credit_micros_used
        reserved = user.premium_credit_micros_reserved
        remaining = max(0, limit - used - reserved)

        warning_threshold = int(limit * 0.8)

@ -16,11 +16,14 @@ from __future__ import annotations

import dataclasses
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID

import litellm
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession

@ -35,6 +38,8 @@ class TokenCallRecord:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost_micros: int = 0
    call_kind: str = "chat"


@dataclass

@ -49,6 +54,8 @@ class TurnTokenAccumulator:
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        cost_micros: int = 0,
        call_kind: str = "chat",
    ) -> None:
        self.calls.append(
            TokenCallRecord(

@ -56,20 +63,28 @@ class TurnTokenAccumulator:
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                cost_micros=cost_micros,
                call_kind=call_kind,
            )
        )

    def per_message_summary(self) -> dict[str, dict[str, int]]:
        """Return token counts grouped by model name."""
        """Return token counts (and cost) grouped by model name."""
        by_model: dict[str, dict[str, int]] = {}
        for c in self.calls:
            entry = by_model.setdefault(
                c.model,
                {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost_micros": 0,
                },
            )
            entry["prompt_tokens"] += c.prompt_tokens
            entry["completion_tokens"] += c.completion_tokens
            entry["total_tokens"] += c.total_tokens
            entry["cost_micros"] += c.cost_micros
        return by_model

    @property

@ -84,6 +99,21 @@ class TurnTokenAccumulator:
    def total_completion_tokens(self) -> int:
        return sum(c.completion_tokens for c in self.calls)

    @property
    def total_cost_micros(self) -> int:
        """Sum of per-call ``cost_micros`` across the entire turn.

        Used by ``stream_new_chat`` to debit a premium turn's actual
        provider cost (in micro-USD) from the user's premium credit
        balance. ``cost_micros`` per call is captured by
        ``TokenTrackingCallback.async_log_success_event`` from
        ``kwargs["response_cost"]`` (LiteLLM's auto-calculated cost),
        with multiple fallback paths so OpenRouter dynamic models and
        custom Azure deployments still bill correctly when our
        ``pricing_registration`` ran at startup.
        """
        return sum(c.cost_micros for c in self.calls)

    def serialized_calls(self) -> list[dict[str, Any]]:
        return [dataclasses.asdict(c) for c in self.calls]

@ -94,7 +124,14 @@ _turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(


def start_turn() -> TurnTokenAccumulator:
    """Create a fresh accumulator for the current async context and return it."""
    """Create a fresh accumulator for the current async context and return it.

    NOTE: Used by ``stream_new_chat`` for the long-lived chat turn. For
    short-lived per-call billable wrappers (image generation REST endpoint,
    vision LLM during indexing) prefer :func:`scoped_turn`, which uses a
    ContextVar reset token to restore the *previous* accumulator on exit and
    avoids leaking call records across reservations (issue B).
    """
    acc = TurnTokenAccumulator()
    _turn_accumulator.set(acc)
    logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))

@ -105,6 +142,140 @@ def get_current_accumulator() -> TurnTokenAccumulator | None:
    return _turn_accumulator.get()


@asynccontextmanager
async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]:
    """Async context manager that scopes a fresh ``TurnTokenAccumulator``
    for the duration of the ``async with`` block, then *resets* the
    ContextVar to its previous value on exit.

    This is the safe primitive for per-call billable operations
    (image generation, vision LLM extraction, podcasts) that may run
    inside an outer chat turn or be called sequentially from the same
    background worker. Using ``ContextVar.set`` without ``reset`` (as
    :func:`start_turn` does) would leak the inner accumulator into the
    outer scope, causing the outer chat turn to debit cost twice.

    Usage::

        async with scoped_turn() as acc:
            await llm.ainvoke(...)
            # acc.total_cost_micros captures cost from the LiteLLM callback
        # Outer accumulator (if any) is restored here.
    """
    acc = TurnTokenAccumulator()
    token = _turn_accumulator.set(acc)
    logger.debug(
        "[TokenTracking] scoped_turn: enter (acc id=%s, prev token=%s)",
        id(acc),
        token,
    )
    try:
        yield acc
    finally:
        _turn_accumulator.reset(token)
        logger.debug(
            "[TokenTracking] scoped_turn: exit (acc id=%s captured %d call(s), %d micros total)",
            id(acc),
            len(acc.calls),
            acc.total_cost_micros,
        )


def _extract_cost_usd(
    kwargs: dict[str, Any],
    response_obj: Any,
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
    is_image: bool = False,
) -> float:
    """Best-effort USD cost extraction for a single LLM/image call.

    Tries four sources in priority order and returns the first that
    yields a positive number; returns 0.0 if all four fail (the call
    will then debit nothing from the user's balance — fail-safe).

    Sources:
      1. ``kwargs["response_cost"]`` — LiteLLM's standard callback
         field, populated for ``Router.acompletion`` since PR #12500.
      2. ``response_obj._hidden_params["response_cost"]`` — same value
         exposed on the response itself.
      3. ``litellm.completion_cost(completion_response=response_obj)``
         — recompute from the response and LiteLLM's pricing table.
      4. ``litellm.cost_per_token(model, prompt_tokens, completion_tokens)``
         — manual fallback for OpenRouter/custom-Azure models that
         only resolve via aliases registered by
         ``pricing_registration`` at startup. **Skipped for image
         responses** — ``cost_per_token`` does not support ``ImageResponse``
         and would raise; the cost map for image-gen lives in different
         keys (``output_cost_per_image``) handled by ``completion_cost``.
    """
    cost = kwargs.get("response_cost")
    if cost is not None:
        try:
            value = float(cost)
        except (TypeError, ValueError):
            value = 0.0
        if value > 0:
            return value

    hidden = getattr(response_obj, "_hidden_params", None) or {}
    if isinstance(hidden, dict):
        cost = hidden.get("response_cost")
        if cost is not None:
            try:
                value = float(cost)
            except (TypeError, ValueError):
                value = 0.0
            if value > 0:
                return value

    try:
        value = float(litellm.completion_cost(completion_response=response_obj))
        if value > 0:
            return value
    except Exception as exc:
        if is_image:
            # Image-gen path: OpenRouter's image responses can omit
            # ``usage.cost`` and LiteLLM's ``default_image_cost_calculator``
            # then *raises* (no cost map for OpenRouter image models).
            # Bail out with a warning rather than falling through to
            # cost_per_token (which is also incompatible with ImageResponse).
            logger.warning(
                "[TokenTracking] completion_cost failed for image model=%s "
                "(provider may have omitted usage.cost). Debiting 0. "
                "Cause: %s",
                model,
                exc,
            )
            return 0.0
        logger.debug(
            "[TokenTracking] completion_cost failed for model=%s: %s", model, exc
        )

    if is_image:
        # Never call cost_per_token for ImageResponse — keys mismatch and
        # the function is documented chat-only.
        return 0.0

    if model and (prompt_tokens > 0 or completion_tokens > 0):
        try:
            prompt_cost, completion_cost = litellm.cost_per_token(
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            value = float(prompt_cost) + float(completion_cost)
            if value > 0:
                return value
        except Exception as exc:
            logger.debug(
                "[TokenTracking] cost_per_token failed for model=%s: %s", model, exc
            )

    return 0.0


class TokenTrackingCallback(CustomLogger):
    """LiteLLM callback that captures token usage into the turn accumulator."""

@ -122,6 +293,13 @@ class TokenTrackingCallback(CustomLogger):
            )
            return

        # Detect image generation responses — they have a different usage
        # shape (ImageUsage with input_tokens/output_tokens) and require a
        # different cost-extraction path. We probe by class name to avoid a
        # hard import dependency on litellm internals.
        response_cls = type(response_obj).__name__
        is_image = response_cls == "ImageResponse"

        usage = getattr(response_obj, "usage", None)
        if not usage:
            logger.debug(

@ -129,24 +307,66 @@ class TokenTrackingCallback(CustomLogger):
            )
            return

        prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        total_tokens = getattr(usage, "total_tokens", 0) or 0
        if is_image:
            # ``ImageUsage`` exposes ``input_tokens`` / ``output_tokens``
            # (not prompt_tokens/completion_tokens). Several providers
            # populate only one or neither (e.g. OpenRouter's gpt-image-1
            # passes through `input_tokens` from the prompt but no
            # completion); fall through gracefully to 0.
            prompt_tokens = getattr(usage, "input_tokens", 0) or 0
            completion_tokens = getattr(usage, "output_tokens", 0) or 0
            total_tokens = (
                getattr(usage, "total_tokens", 0) or prompt_tokens + completion_tokens
            )
            call_kind = "image_generation"
        else:
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0
            total_tokens = getattr(usage, "total_tokens", 0) or 0
            call_kind = "chat"

        model = kwargs.get("model", "unknown")

        cost_usd = _extract_cost_usd(
            kwargs=kwargs,
            response_obj=response_obj,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            is_image=is_image,
        )
        cost_micros = round(cost_usd * 1_000_000) if cost_usd > 0 else 0

        if cost_micros == 0 and (prompt_tokens > 0 or completion_tokens > 0):
            logger.warning(
                "[TokenTracking] No cost resolved for model=%s prompt=%d completion=%d "
                "kind=%s — debiting 0. Register pricing via pricing_registration or YAML "
                "input_cost_per_token/output_cost_per_token (or rely on response_cost "
                "for image generation).",
                model,
                prompt_tokens,
                completion_tokens,
                call_kind,
            )

        acc.add(
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cost_micros=cost_micros,
            call_kind=call_kind,
        )
        logger.info(
            "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
            "[TokenTracking] Captured: model=%s kind=%s prompt=%d completion=%d total=%d "
            "cost=$%.6f (%d micros) (accumulator now has %d calls)",
            model,
            call_kind,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost_usd,
            cost_micros,
            len(acc.calls),
        )

@ -168,6 +388,7 @@ async def record_token_usage(
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    total_tokens: int = 0,
    cost_micros: int = 0,
    model_breakdown: dict[str, Any] | None = None,
    call_details: dict[str, Any] | None = None,
    thread_id: int | None = None,

@ -185,6 +406,7 @@ async def record_token_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cost_micros=cost_micros,
            model_breakdown=model_breakdown,
            call_details=call_details,
            thread_id=thread_id,

@ -194,11 +416,12 @@ async def record_token_usage(
        )
        session.add(record)
        logger.debug(
            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d cost_micros=%d",
            usage_type,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost_micros,
        )
        return record
    except Exception:

@ -3,6 +3,8 @@ from typing import Any

from litellm import Router

from app.services.provider_api_base import resolve_api_base

logger = logging.getLogger(__name__)

VISION_AUTO_MODE_ID = 0

@ -108,10 +110,11 @@ class VisionLLMRouterService:
        if not config.get("model_name") or not config.get("api_key"):
            return None

        provider = config.get("provider", "").upper()
        if config.get("custom_provider"):
            model_string = f"{config['custom_provider']}/{config['model_name']}"
            provider_prefix = config["custom_provider"]
            model_string = f"{provider_prefix}/{config['model_name']}"
        else:
            provider = config.get("provider", "").upper()
            provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
            model_string = f"{provider_prefix}/{config['model_name']}"

@ -120,8 +123,13 @@ class VisionLLMRouterService:
            "api_key": config.get("api_key"),
        }

        if config.get("api_base"):
            litellm_params["api_base"] = config["api_base"]
        api_base = resolve_api_base(
            provider=provider,
            provider_prefix=provider_prefix,
            config_api_base=config.get("api_base"),
        )
        if api_base:
            litellm_params["api_base"] = api_base

        if config.get("api_version"):
            litellm_params["api_version"] = config["api_version"]