mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-08 15:22:39 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/memory-extraction
This commit is contained in:
commit
b981b51ab1
176 changed files with 20407 additions and 6258 deletions
|
|
@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
|
|||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.config import config
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
|
||||
|
|
@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
|
|||
|
||||
enable_otel: bool
|
||||
|
||||
enable_desktop_local_filesystem: bool
|
||||
|
||||
@classmethod
|
||||
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
|
||||
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
|
||||
return cls(**asdict(flags))
|
||||
return cls(
|
||||
**asdict(flags),
|
||||
enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
||||
|
|
|
|||
|
|
@ -649,13 +649,9 @@ async def list_composio_drive_folders(
|
|||
"""
|
||||
List folders AND files in user's Google Drive via Composio.
|
||||
|
||||
Uses the same GoogleDriveClient / list_folder_contents path as the native
|
||||
connector, with Composio-sourced credentials. This means auth errors
|
||||
propagate identically (Google returns 401 → exception → auth_expired flag).
|
||||
Uses Composio's Google Drive tool execution path so managed OAuth tokens
|
||||
do not need to be exposed through connected account state.
|
||||
"""
|
||||
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
|
|
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
|
|||
detail="Composio connected account not found. Please reconnect the connector.",
|
||||
)
|
||||
|
||||
credentials = build_composio_credentials(composio_connected_account_id)
|
||||
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
|
||||
service = ComposioService()
|
||||
entity_id = f"surfsense_{user.id}"
|
||||
items = []
|
||||
page_token = None
|
||||
error = None
|
||||
|
||||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
||||
while True:
|
||||
page_items, next_token, page_error = await service.get_drive_files(
|
||||
connected_account_id=composio_connected_account_id,
|
||||
entity_id=entity_id,
|
||||
folder_id=parent_id,
|
||||
page_token=page_token,
|
||||
page_size=100,
|
||||
)
|
||||
if page_error:
|
||||
error = page_error
|
||||
break
|
||||
|
||||
items.extend(page_items)
|
||||
if not next_token:
|
||||
break
|
||||
page_token = next_token
|
||||
|
||||
for item in items:
|
||||
item["isFolder"] = (
|
||||
item.get("mimeType") == "application/vnd.google-apps.folder"
|
||||
)
|
||||
|
||||
items.sort(
|
||||
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
|
||||
)
|
||||
|
||||
if error:
|
||||
error_lower = error.lower()
|
||||
|
|
|
|||
|
|
@ -36,11 +36,17 @@ from app.schemas import (
|
|||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.billable_calls import (
|
||||
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||
QuotaInsufficientError,
|
||||
billable_call,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
"""Resolve the LiteLLM provider prefix used in model strings."""
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
|
||||
session: AsyncSession,
|
||||
config_id: int | None,
|
||||
search_space: SearchSpace,
|
||||
) -> tuple[str, str, int]:
|
||||
"""Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
|
||||
|
||||
The resolution mirrors ``_execute_image_generation``'s lookup tree but
|
||||
only extracts the fields needed for billing — we do this *before*
|
||||
``billable_call`` so the reservation is correctly sized for the
|
||||
config that will actually run, and so we don't open an
|
||||
``ImageGeneration`` row for a request that's about to 402.
|
||||
|
||||
User-owned (positive ID) BYOK configs are always free — they cost
|
||||
the user nothing on our side. Auto mode currently treats as free
|
||||
because the underlying router can dispatch to either premium or
|
||||
free YAML configs and we don't surface the resolved deployment up
|
||||
here yet. Bringing Auto under premium billing would require
|
||||
threading the chosen deployment back from ``ImageGenRouterService``.
|
||||
"""
|
||||
resolved_id = config_id
|
||||
if resolved_id is None:
|
||||
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
|
||||
if is_image_gen_auto_mode(resolved_id):
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
if resolved_id < 0:
|
||||
cfg = _get_global_image_gen_config(resolved_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
base_model = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg.get("model_name", ""),
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
reserve_micros = int(
|
||||
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
)
|
||||
return (billing_tier, base_model, reserve_micros)
|
||||
|
||||
# Positive ID = user-owned BYOK image-gen config — always free.
|
||||
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
|
||||
async def _execute_image_generation(
|
||||
|
|
@ -138,12 +192,18 @@ async def _execute_image_generation(
|
|||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -165,12 +225,18 @@ async def _execute_image_generation(
|
|||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
|
|
@ -225,10 +291,15 @@ async def get_global_image_gen_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -241,6 +312,12 @@ async def get_global_image_gen_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -454,7 +531,26 @@ async def create_image_generation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create and execute an image generation request."""
|
||||
"""Create and execute an image generation request.
|
||||
|
||||
Premium configs are gated by the user's shared premium credit pool.
|
||||
The flow is:
|
||||
|
||||
1. Permission check + load the search space (cheap, no provider call).
|
||||
2. Resolve which config will run so we know its billing tier and the
|
||||
worst-case reservation size *before* opening any DB rows.
|
||||
3. Wrap the entire ImageGeneration row insert + provider call in
|
||||
``billable_call``. If quota is denied, ``billable_call`` raises
|
||||
``QuotaInsufficientError`` *before* we flush a row, which we
|
||||
translate to HTTP 402 (no orphaned rows on the user's account,
|
||||
no inserted error rows for "you ran out of credit").
|
||||
4. On success, the actual ``response_cost`` flows through the
|
||||
LiteLLM callback into the accumulator, and ``billable_call``
|
||||
finalizes the debit at exit. Inner ``try/except`` still catches
|
||||
provider errors and stores them on ``error_message`` (HTTP 200
|
||||
with ``error_message`` set is preserved for failed-but-not-quota
|
||||
scenarios — clients already know how to surface those).
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -471,33 +567,70 @@ async def create_image_generation(
|
|||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||
session, data.image_generation_config_id, search_space
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||
# propagates to the outer ``except QuotaInsufficientError`` handler
|
||||
# below as HTTP 402 — it is intentionally NOT swallowed into
|
||||
# ``error_message`` because that would (1) imply a successful row
|
||||
# exists when none does, and (2) return HTTP 200 to a client
|
||||
# whose request was actively *denied* (issue K).
|
||||
async with billable_call(
|
||||
user_id=search_space.user_id,
|
||||
search_space_id=data.search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=reserve_micros,
|
||||
usage_type="image_generation",
|
||||
call_details={"model": base_model, "prompt": data.prompt[:100]},
|
||||
):
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except QuotaInsufficientError as exc:
|
||||
# The user's premium credit pool is empty. No DB row is created
|
||||
# because ``billable_call`` denies before yielding (issue K).
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error_code": "premium_quota_exhausted",
|
||||
"usage_type": exc.usage_type,
|
||||
"used_micros": exc.used_micros,
|
||||
"limit_micros": exc.limit_micros,
|
||||
"remaining_micros": exc.remaining_micros,
|
||||
"message": (
|
||||
"Out of premium credits for image generation. "
|
||||
"Purchase additional credits or switch to a free model."
|
||||
),
|
||||
},
|
||||
) from exc
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -1366,7 +1366,11 @@ async def append_message(
|
|||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
|
||||
# Persist token usage if provided (for assistant messages)
|
||||
# Persist token usage if provided (for assistant messages).
|
||||
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||
# forwarded by the FE through the appendMessage round-trip so
|
||||
# the historical TokenUsage row matches the credit debit applied
|
||||
# at finalize time.
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
await record_token_usage(
|
||||
|
|
@ -1377,6 +1381,7 @@ async def append_message(
|
|||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.schemas import (
|
|||
NewLLMConfigUpdate,
|
||||
)
|
||||
from app.services.llm_service import validate_llm_config
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -36,6 +37,39 @@ router = APIRouter()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
|
||||
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
|
||||
|
||||
There is no DB column for ``supports_image_input`` — the value is
|
||||
resolved at the API boundary from LiteLLM's authoritative model map
|
||||
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
|
||||
the response shape consistent across list / detail / create / update
|
||||
endpoints without having to remember to set the field at every call
|
||||
site.
|
||||
"""
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
)
|
||||
# ``model_validate`` runs the Pydantic conversion using the ORM
|
||||
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
|
||||
# then we layer the derived field on. ``model_copy(update=...)`` keeps
|
||||
# the surface immutable from the caller's perspective.
|
||||
base_read = NewLLMConfigRead.model_validate(config)
|
||||
return base_read.model_copy(update={"supports_image_input": supports_image_input})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Configs Routes
|
||||
# =============================================================================
|
||||
|
|
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": None,
|
||||
"seo_description": None,
|
||||
"quota_reserve_tokens": None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# includes at least one vision-capable deployment, so
|
||||
# treat Auto as image-capable. The router itself will
|
||||
# still pick a vision-capable deployment for messages
|
||||
# carrying image_url blocks (LiteLLM Router falls back
|
||||
# on ``404`` per its ``allowed_fails`` policy).
|
||||
"supports_image_input": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add individual global configs
|
||||
for cfg in global_configs:
|
||||
# Capability resolution: explicit value (YAML override or OR
|
||||
# `_supports_image_input(model)` payload baked in by the
|
||||
# OpenRouter integration service) wins. Fall back to the
|
||||
# LiteLLM-driven helper which default-allows on unknown so
|
||||
# we don't hide vision-capable models that happen to lack a
|
||||
# YAML annotation. The streaming task safety net is the
|
||||
# only place a False ever blocks.
|
||||
if "supports_image_input" in cfg:
|
||||
supports_image_input = bool(cfg.get("supports_image_input"))
|
||||
else:
|
||||
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||
cfg_base_model = (
|
||||
cfg_litellm_params.get("base_model")
|
||||
if isinstance(cfg_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=cfg_base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
safe_config = {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
|
|
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": cfg.get("seo_title"),
|
||||
"seo_description": cfg.get("seo_description"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
}
|
||||
safe_configs.append(safe_config)
|
||||
|
||||
|
|
@ -171,7 +236,7 @@ async def create_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
||||
return db_config
|
||||
return _serialize_byok_config(db_config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -213,7 +278,7 @@ async def list_new_llm_configs(
|
|||
.limit(limit)
|
||||
)
|
||||
|
||||
return result.scalars().all()
|
||||
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -268,7 +333,7 @@ async def get_new_llm_config(
|
|||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -360,7 +425,7 @@ async def update_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -591,6 +591,7 @@ async def _get_image_gen_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -607,6 +608,7 @@ async def _get_image_gen_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
@ -649,6 +651,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -665,6 +668,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
|
|||
metadata = _get_metadata(checkout_session)
|
||||
user_id = metadata.get("user_id")
|
||||
quantity = int(metadata.get("quantity", "0"))
|
||||
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
|
||||
# Read the new metadata key first, fall back to the legacy one so
|
||||
# in-flight checkout sessions created before the cost-credits
|
||||
# release still fulfil correctly (the unit is numerically the
|
||||
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
|
||||
credit_micros_per_unit = int(
|
||||
metadata.get("credit_micros_per_unit")
|
||||
or metadata.get("tokens_per_unit", "0")
|
||||
)
|
||||
|
||||
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
|
||||
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
|
||||
logger.error(
|
||||
"Skipping token fulfillment for session %s: incomplete metadata %s",
|
||||
checkout_session_id,
|
||||
|
|
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=quantity,
|
||||
tokens_granted=quantity * tokens_per_unit,
|
||||
credit_micros_granted=quantity * credit_micros_per_unit,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
|
|||
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
)
|
||||
user.premium_tokens_limit = (
|
||||
max(user.premium_tokens_used, user.premium_tokens_limit)
|
||||
+ purchase.tokens_granted
|
||||
# Top up the user's credit balance by the granted micro-USD amount.
|
||||
# ``max(used, limit)`` clamps the case where the legacy code wrote a
|
||||
# used value above the limit (e.g. underbilling rounding) so adding
|
||||
# ``credit_micros_granted`` always lifts the limit by the full pack
|
||||
# size rather than disappearing into past overuse.
|
||||
user.premium_credit_micros_limit = (
|
||||
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
|
||||
+ purchase.credit_micros_granted
|
||||
)
|
||||
|
||||
await db_session.commit()
|
||||
|
|
@ -532,12 +544,18 @@ async def create_token_checkout_session(
|
|||
user: User = Depends(current_active_user),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Create a Stripe Checkout Session for buying premium token packs."""
|
||||
"""Create a Stripe Checkout Session for buying premium credit packs.
|
||||
|
||||
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
|
||||
credit (default 1_000_000 = $1.00). The user's balance is debited
|
||||
at the actual provider cost reported by LiteLLM at finalize time,
|
||||
so $1 of credit always buys $1 worth of provider usage at cost.
|
||||
"""
|
||||
_ensure_token_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_required_token_price_id()
|
||||
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
|
||||
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
|
||||
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
|
||||
|
||||
try:
|
||||
checkout_session = stripe_client.v1.checkout.sessions.create(
|
||||
|
|
@ -556,8 +574,8 @@ async def create_token_checkout_session(
|
|||
"metadata": {
|
||||
"user_id": str(user.id),
|
||||
"quantity": str(body.quantity),
|
||||
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
|
||||
"purchase_type": "premium_tokens",
|
||||
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
|
||||
"purchase_type": "premium_credit",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -583,7 +601,7 @@ async def create_token_checkout_session(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=body.quantity,
|
||||
tokens_granted=tokens_granted,
|
||||
credit_micros_granted=credit_micros_granted,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -598,14 +616,19 @@ async def create_token_checkout_session(
|
|||
async def get_token_status(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return token-buying availability and current premium quota for frontend."""
|
||||
used = user.premium_tokens_used
|
||||
limit = user.premium_tokens_limit
|
||||
"""Return token-buying availability and current premium credit quota for frontend.
|
||||
|
||||
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
|
||||
when displaying. The route name is preserved for back-compat with
|
||||
pinned client deployments.
|
||||
"""
|
||||
used = user.premium_credit_micros_used
|
||||
limit = user.premium_credit_micros_limit
|
||||
return TokenStripeStatusResponse(
|
||||
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
|
||||
premium_tokens_used=used,
|
||||
premium_tokens_limit=limit,
|
||||
premium_tokens_remaining=max(0, limit - used),
|
||||
premium_credit_micros_used=used,
|
||||
premium_credit_micros_limit=limit,
|
||||
premium_credit_micros_remaining=max(0, limit - used),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,10 +82,15 @@ async def get_global_vision_llm_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -98,6 +103,14 @@ async def get_global_vision_llm_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue