mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 23:02:39 +02:00
Merge upstream/dev into feature/multi-agent
This commit is contained in:
commit
5119915f4f
278 changed files with 34669 additions and 8970 deletions
|
|
@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
|
|||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.config import config
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
|
||||
|
|
@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
|
|||
|
||||
enable_otel: bool
|
||||
|
||||
enable_desktop_local_filesystem: bool
|
||||
|
||||
@classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
    """Build the read model from a resolved ``AgentFeatureFlags`` instance.

    Every field of ``flags`` is copied through ``asdict``; the
    ``enable_desktop_local_filesystem`` field is not taken from ``flags``
    but from ``config.ENABLE_DESKTOP_LOCAL_FILESYSTEM``.
    """
    # asdict() avoids missing-field bugs when AgentFeatureFlags grows.
    return cls(
        **asdict(flags),
        enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
    )
|
||||
|
||||
|
||||
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
||||
|
|
|
|||
|
|
@ -649,13 +649,9 @@ async def list_composio_drive_folders(
|
|||
"""
|
||||
List folders AND files in user's Google Drive via Composio.
|
||||
|
||||
Uses the same GoogleDriveClient / list_folder_contents path as the native
|
||||
connector, with Composio-sourced credentials. This means auth errors
|
||||
propagate identically (Google returns 401 → exception → auth_expired flag).
|
||||
Uses Composio's Google Drive tool execution path so managed OAuth tokens
|
||||
do not need to be exposed through connected account state.
|
||||
"""
|
||||
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
|
|
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
|
|||
detail="Composio connected account not found. Please reconnect the connector.",
|
||||
)
|
||||
|
||||
credentials = build_composio_credentials(composio_connected_account_id)
|
||||
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
|
||||
service = ComposioService()
|
||||
entity_id = f"surfsense_{user.id}"
|
||||
items = []
|
||||
page_token = None
|
||||
error = None
|
||||
|
||||
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
|
||||
while True:
|
||||
page_items, next_token, page_error = await service.get_drive_files(
|
||||
connected_account_id=composio_connected_account_id,
|
||||
entity_id=entity_id,
|
||||
folder_id=parent_id,
|
||||
page_token=page_token,
|
||||
page_size=100,
|
||||
)
|
||||
if page_error:
|
||||
error = page_error
|
||||
break
|
||||
|
||||
items.extend(page_items)
|
||||
if not next_token:
|
||||
break
|
||||
page_token = next_token
|
||||
|
||||
for item in items:
|
||||
item["isFolder"] = (
|
||||
item.get("mimeType") == "application/vnd.google-apps.folder"
|
||||
)
|
||||
|
||||
items.sort(
|
||||
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
|
||||
)
|
||||
|
||||
if error:
|
||||
error_lower = error.lower()
|
||||
|
|
|
|||
|
|
@ -745,6 +745,51 @@ async def search_document_titles(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead)
async def get_document_by_virtual_path(
    search_space_id: int,
    virtual_path: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Resolve a knowledge-base document id by exact virtual path.

    Args:
        search_space_id: Search space whose documents are searched.
        virtual_path: Value compared for exact equality against the
            ``virtual_path`` key of ``Document.document_metadata``.
        session: Async database session used for the permission check and
            the lookup query.
        user: Authenticated user making the request.

    Returns:
        A ``DocumentTitleRead`` holding the id, title and document type of
        the first matching row.

    Raises:
        HTTPException: 404 when no row matches. Any ``HTTPException`` raised
            inside the handler is re-raised unchanged; every other exception
            is wrapped in a 500 response.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.DOCUMENTS_READ.value,
            "You don't have permission to read documents in this search space",
        )

        # Only the three columns needed for the response are selected.
        result = await session.execute(
            select(
                Document.id,
                Document.title,
                Document.document_type,
            ).filter(
                Document.search_space_id == search_space_id,
                Document.document_metadata["virtual_path"].as_string() == virtual_path,
            )
        )
        row = result.first()
        if row is None:
            raise HTTPException(status_code=404, detail="Document not found")

        return DocumentTitleRead(
            id=row.id,
            title=row.title,
            document_type=row.document_type,
        )
    except HTTPException:
        # Keep deliberate HTTP errors (such as the 404 above) intact.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to resolve document by virtual path: {e!s}",
        ) from e
|
||||
|
||||
|
||||
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
|
||||
async def get_documents_status(
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -36,11 +36,17 @@ from app.schemas import (
|
|||
ImageGenerationListRead,
|
||||
ImageGenerationRead,
|
||||
)
|
||||
from app.services.billable_calls import (
|
||||
DEFAULT_IMAGE_RESERVE_MICROS,
|
||||
QuotaInsufficientError,
|
||||
billable_call,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
"""Resolve the LiteLLM provider prefix used in model strings."""
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Build a litellm model string from provider + model_name.

    The prefix comes from ``_resolve_provider_prefix`` so the prefix rules
    live in one place; the result has the form ``"<prefix>/<model_name>"``.
    """
    return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
    session: AsyncSession,
    config_id: int | None,
    search_space: SearchSpace,
) -> tuple[str, str, int]:
    """Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.

    The resolution mirrors ``_execute_image_generation``'s lookup tree but
    only extracts the fields needed for billing — we do this *before*
    ``billable_call`` so the reservation is correctly sized for the
    config that will actually run, and so we don't open an
    ``ImageGeneration`` row for a request that's about to 402.

    User-owned (positive ID) BYOK configs are always free — they cost
    the user nothing on our side. Auto mode is currently treated as free
    because the underlying router can dispatch to either premium or
    free YAML configs and we don't surface the resolved deployment up
    here yet. Bringing Auto under premium billing would require
    threading the chosen deployment back from ``ImageGenRouterService``.

    Args:
        session: Async database session (not used by the current lookup).
        config_id: Explicitly requested config id, or ``None`` to fall back
            to the search space's configured id and then to Auto mode.
        search_space: Search space supplying the fallback config id.

    Returns:
        A ``(billing_tier, base_model, reserve_micros)`` tuple.
    """
    resolved_id = config_id
    if resolved_id is None:
        resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID

    if is_image_gen_auto_mode(resolved_id):
        return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)

    if resolved_id < 0:
        # Negative ID = global config; a missing config falls back to an
        # empty dict and therefore to the "free" tier defaults below.
        cfg = _get_global_image_gen_config(resolved_id) or {}
        billing_tier = str(cfg.get("billing_tier", "free")).lower()
        base_model = _build_model_string(
            cfg.get("provider", ""),
            cfg.get("model_name", ""),
            cfg.get("custom_provider"),
        )
        reserve_micros = int(
            cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
        )
        return (billing_tier, base_model, reserve_micros)

    # Positive ID = user-owned BYOK image-gen config — always free.
    return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
||||
|
||||
async def _execute_image_generation(
|
||||
|
|
@ -138,12 +192,18 @@ async def _execute_image_generation(
|
|||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -165,12 +225,18 @@ async def _execute_image_generation(
|
|||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
|
|
@ -225,10 +291,15 @@ async def get_global_image_gen_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -241,6 +312,12 @@ async def get_global_image_gen_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -454,7 +531,26 @@ async def create_image_generation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create and execute an image generation request."""
|
||||
"""Create and execute an image generation request.
|
||||
|
||||
Premium configs are gated by the user's shared premium credit pool.
|
||||
The flow is:
|
||||
|
||||
1. Permission check + load the search space (cheap, no provider call).
|
||||
2. Resolve which config will run so we know its billing tier and the
|
||||
worst-case reservation size *before* opening any DB rows.
|
||||
3. Wrap the entire ImageGeneration row insert + provider call in
|
||||
``billable_call``. If quota is denied, ``billable_call`` raises
|
||||
``QuotaInsufficientError`` *before* we flush a row, which we
|
||||
translate to HTTP 402 (no orphaned rows on the user's account,
|
||||
no inserted error rows for "you ran out of credit").
|
||||
4. On success, the actual ``response_cost`` flows through the
|
||||
LiteLLM callback into the accumulator, and ``billable_call``
|
||||
finalizes the debit at exit. Inner ``try/except`` still catches
|
||||
provider errors and stores them on ``error_message`` (HTTP 200
|
||||
with ``error_message`` set is preserved for failed-but-not-quota
|
||||
scenarios — clients already know how to surface those).
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -471,33 +567,70 @@ async def create_image_generation(
|
|||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
|
||||
session, data.image_generation_config_id, search_space
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
|
||||
# propagates to the outer ``except QuotaInsufficientError`` handler
|
||||
# below as HTTP 402 — it is intentionally NOT swallowed into
|
||||
# ``error_message`` because that would (1) imply a successful row
|
||||
# exists when none does, and (2) return HTTP 200 to a client
|
||||
# whose request was actively *denied* (issue K).
|
||||
async with billable_call(
|
||||
user_id=search_space.user_id,
|
||||
search_space_id=data.search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=reserve_micros,
|
||||
usage_type="image_generation",
|
||||
call_details={"model": base_model, "prompt": data.prompt[:100]},
|
||||
):
|
||||
db_image_gen = ImageGeneration(
|
||||
prompt=data.prompt,
|
||||
model=data.model,
|
||||
n=data.n,
|
||||
quality=data.quality,
|
||||
size=data.size,
|
||||
style=data.style,
|
||||
response_format=data.response_format,
|
||||
image_generation_config_id=data.image_generation_config_id,
|
||||
search_space_id=data.search_space_id,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
session.add(db_image_gen)
|
||||
await session.flush()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
try:
|
||||
await _execute_image_generation(session, db_image_gen, search_space)
|
||||
except Exception as e:
|
||||
logger.exception("Image generation call failed")
|
||||
db_image_gen.error_message = str(e)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_image_gen)
|
||||
return db_image_gen
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except QuotaInsufficientError as exc:
|
||||
# The user's premium credit pool is empty. No DB row is created
|
||||
# because ``billable_call`` denies before yielding (issue K).
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error_code": "premium_quota_exhausted",
|
||||
"usage_type": exc.usage_type,
|
||||
"used_micros": exc.used_micros,
|
||||
"limit_micros": exc.limit_micros,
|
||||
"remaining_micros": exc.remaining_micros,
|
||||
"message": (
|
||||
"Out of premium credits for image generation. "
|
||||
"Purchase additional credits or switch to a free model."
|
||||
),
|
||||
},
|
||||
) from exc
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ import json
|
|||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy import func, or_, text as sa_text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
|
@ -29,6 +30,12 @@ from app.agents.new_chat.filesystem_selection import (
|
|||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
manager,
|
||||
request_cancel,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
|
|
@ -38,12 +45,14 @@ from app.db import (
|
|||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
TokenUsage,
|
||||
User,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
AgentToolInfo,
|
||||
CancelActiveTurnResponse,
|
||||
LocalFilesystemMountPayload,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
|
|
@ -60,10 +69,11 @@ from app.schemas.new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
TokenUsageSummary,
|
||||
TurnStatusResponse,
|
||||
)
|
||||
from app.services.token_tracking_service import record_token_usage
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.perf import get_perf_logger
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.user_message_multimodal import (
|
||||
split_langchain_human_content,
|
||||
|
|
@ -71,7 +81,11 @@ from app.utils.user_message_multimodal import (
|
|||
)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -137,6 +151,72 @@ def _resolve_filesystem_selection(
|
|||
)
|
||||
|
||||
|
||||
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay for TURN_CANCELLING retry hints.

    Args:
        attempt: 1-based attempt counter; values below 1 are treated as 1.

    Returns:
        ``TURN_CANCELLING_INITIAL_DELAY_MS`` multiplied by
        ``TURN_CANCELLING_BACKOFF_FACTOR`` once per attempt after the first,
        capped at ``TURN_CANCELLING_MAX_DELAY_MS``.
    """
    if attempt < 1:
        attempt = 1
    delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
    )
    return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||
|
||||
|
||||
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Describe the current turn state of a thread as a status payload.

    Returns:
        ``{"status": "idle"}`` when the thread's lock is not held;
        a ``"cancelling"`` payload with ``retry_after_ms`` and
        ``retry_after_at`` (epoch milliseconds) when the lock is held and a
        cancel has been requested; otherwise ``{"status": "busy"}``.
    """
    thread_key = str(thread_id)
    lock = manager.lock_for(thread_key)
    if not lock.locked():
        return {"status": "idle"}

    if is_cancel_requested(thread_key):
        cancel_state = get_cancel_state(thread_key)
        # The first element of the cancel state is used as the attempt
        # counter; default to the first attempt when no state is available.
        attempt = cancel_state[0] if cancel_state else 1
        retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
        retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
        return {
            "status": "cancelling",
            "retry_after_ms": retry_after_ms,
            "retry_after_at": retry_after_at,
        }

    return {"status": "busy"}
|
||||
|
||||
|
||||
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
|
||||
response.headers["retry-after-ms"] = str(retry_after_ms)
|
||||
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
|
||||
|
||||
|
||||
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Reject starting a new turn while the thread is not idle.

    Returns normally when ``_build_turn_status_payload`` reports ``"idle"``.

    Raises:
        HTTPException: 409 with ``errorCode`` ``"TURN_CANCELLING"`` when the
            status is ``"cancelling"`` (retry headers are attached only when
            a positive ``retry_after_ms`` is available), or 409 with
            ``errorCode`` ``"THREAD_BUSY"`` for any other non-idle status.
    """
    status_payload = _build_turn_status_payload(thread_id)
    status = status_payload["status"]
    if status == "idle":
        return
    if status == "cancelling":
        retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
        detail = {
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
            "retry_after_at": status_payload.get("retry_after_at"),
        }
        headers = (
            {
                "retry-after-ms": str(retry_after_ms),
                # Whole seconds, rounded up, never below 1.
                "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
            }
            if retry_after_ms > 0
            else None
        )
        raise HTTPException(status_code=409, detail=detail, headers=headers)

    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "THREAD_BUSY",
            "message": "Another response is still finishing for this thread. Please try again in a moment.",
        },
    )
|
||||
|
||||
|
||||
def _find_pre_turn_checkpoint_id(
|
||||
checkpoint_tuples: list,
|
||||
*,
|
||||
|
|
@ -1210,6 +1290,24 @@ async def append_message(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
.. deprecated:: 2026-05
|
||||
Replaced by the **SSE-based message ID handshake**. The streaming
|
||||
generator (`stream_new_chat` / `stream_resume_chat`) now persists
|
||||
both the user and assistant rows server-side via
|
||||
``persist_user_turn`` / ``persist_assistant_shell`` and emits
|
||||
``data-user-message-id`` / ``data-assistant-message-id`` SSE events
|
||||
so the frontend can rename its optimistic IDs in real time. The
|
||||
new FE bundle no longer calls this route.
|
||||
|
||||
This handler is retained as a **silent no-op for legacy / cached
|
||||
FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING``
|
||||
pattern means a stale bundle hitting this route after the SSE
|
||||
handshake already wrote the row simply returns the existing row
|
||||
(200 OK) without raising or duplicating data. After a 2-week soak
|
||||
(target: ``[persist_user_turn] outcome=race_recovered`` rate ~0)
|
||||
this entire route — and the FE ``appendMessage`` function — is
|
||||
earmarked for removal.
|
||||
|
||||
Append a message to a thread.
|
||||
This is used by ThreadHistoryAdapter.append() to persist messages.
|
||||
|
||||
|
|
@ -1220,6 +1318,22 @@ async def append_message(
|
|||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
# Capture ``user.id`` as a primitive UUID up front. The
|
||||
# ``current_active_user`` dependency hands us a ``User`` ORM
|
||||
# row bound to ``session``; if the outer ``except
|
||||
# IntegrityError`` block below ever fires (an unexpected
|
||||
# constraint like a foreign key violation — the common
|
||||
# ``(thread_id, turn_id, role)`` race is now handled silently
|
||||
# by ``ON CONFLICT DO NOTHING`` so it never raises) it calls
|
||||
# ``session.rollback()``, which expires every attached ORM
|
||||
# row including this user. Any later ``user.id`` access would
|
||||
# then trigger a lazy PK reload — which on async SQLAlchemy
|
||||
# fails with ``MissingGreenlet`` because the reload happens
|
||||
# outside the awaitable greenlet boundary. Reading ``id``
|
||||
# once here pins the value as a plain UUID so all downstream
|
||||
# uses (TokenUsage insert, response build) are immune.
|
||||
user_uuid = user.id
|
||||
|
||||
# Parse raw body - extract only role and content, ignoring extra fields
|
||||
raw_body = await request.json()
|
||||
role = raw_body.get("role")
|
||||
|
|
@ -1274,37 +1388,166 @@ async def append_message(
|
|||
else None
|
||||
)
|
||||
|
||||
db_message = NewChatMessage(
|
||||
thread_id=thread_id,
|
||||
role=message_role,
|
||||
content=content,
|
||||
author_id=user.id,
|
||||
turn_id=turn_id_value,
|
||||
)
|
||||
session.add(db_message)
|
||||
|
||||
# Update thread's updated_at timestamp
|
||||
# Update thread's updated_at timestamp (always — both insert
|
||||
# and recovery paths represent thread activity).
|
||||
thread.updated_at = datetime.now(UTC)
|
||||
|
||||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
# Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING``
|
||||
# keyed on the ``(thread_id, turn_id, role)`` partial unique
|
||||
# index from migration 141 (``WHERE turn_id IS NOT NULL``).
|
||||
#
|
||||
# Why ON CONFLICT instead of ``session.add() + flush() + except
|
||||
# IntegrityError``:
|
||||
# 1. The conflict between this legacy FE ``appendMessage``
|
||||
# round-trip and the server-side
|
||||
# ``finalize_assistant_turn`` writer is a NORMAL,
|
||||
# *expected* race — every assistant turn fires it. Using
|
||||
# catch-and-recover means asyncpg raises
|
||||
# ``UniqueViolationError`` -> SQLAlchemy wraps it as
|
||||
# ``IntegrityError`` -> our handler catches and recovers.
|
||||
# Functionally fine, but every ``raise`` event lights up
|
||||
# VS Code's debugger (debugpy's ``justMyCode=false`` mode
|
||||
# loses track of the catch frame across SQLAlchemy's
|
||||
# async greenlet boundary, so even ``Raised Exceptions``
|
||||
# being unchecked doesn't reliably suppress the pause).
|
||||
# ON CONFLICT pushes the conflict resolution into Postgres
|
||||
# where no Python exception is constructed at all.
|
||||
# 2. No ``session.rollback()`` -> no expiring of attached
|
||||
# ORM rows -> no risk of ``MissingGreenlet`` from
|
||||
# lazy-loading expired user/thread state later in the
|
||||
# handler.
|
||||
# 3. Cleaner production logs (no SQLAlchemy ``IntegrityError``
|
||||
# tracebacks emitted by uvicorn's logger between the
|
||||
# ``raise`` and our ``except``).
|
||||
#
|
||||
# When ``turn_id_value`` is ``None`` the partial index doesn't
|
||||
# apply and the INSERT proceeds normally. Other constraint
|
||||
# violations (FK, NOT NULL, etc.) still raise ``IntegrityError``
|
||||
# and are caught by the outer ``except IntegrityError`` block
|
||||
# to preserve the legacy 400 behavior.
|
||||
#
|
||||
# Note on ``content``: when we recover the existing row, we
|
||||
# intentionally discard the FE's ``content`` payload from
|
||||
# ``raw_body`` and return the row's existing ``content``. The
|
||||
# streaming task is now the *authoritative writer* for
|
||||
# assistant ``ContentPart[]`` shape (mid-stream
|
||||
# ``AssistantContentBuilder`` -> ``finalize_assistant_turn``)
|
||||
# so the FE's later ``appendMessage`` is just a stale snapshot
|
||||
# of the same data — keeping the server-built rich content
|
||||
# (with full tool-call args / argsText / langchainToolCallId)
|
||||
# is correct, not lossy.
|
||||
insert_stmt = (
|
||||
pg_insert(NewChatMessage)
|
||||
.values(
|
||||
thread_id=thread_id,
|
||||
role=message_role,
|
||||
content=content,
|
||||
author_id=user_uuid,
|
||||
turn_id=turn_id_value,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["thread_id", "turn_id", "role"],
|
||||
index_where=sa_text("turn_id IS NOT NULL"),
|
||||
)
|
||||
.returning(NewChatMessage.id)
|
||||
)
|
||||
inserted_id = (await session.execute(insert_stmt)).scalar()
|
||||
|
||||
# Persist token usage if provided (for assistant messages)
|
||||
if inserted_id is None:
|
||||
# Conflict on partial unique index — server-side stream
|
||||
# already wrote this row. Look it up and reuse it.
|
||||
if turn_id_value is None:
|
||||
# Defensive: ON CONFLICT only fires for ``turn_id IS
|
||||
# NOT NULL`` rows, so this branch should be
|
||||
# unreachable. Preserve the legacy 400 just in case
|
||||
# Postgres ever surprises us.
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
lookup = await session.execute(
|
||||
select(NewChatMessage).filter(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id_value,
|
||||
NewChatMessage.role == message_role,
|
||||
)
|
||||
)
|
||||
existing_message = lookup.scalars().first()
|
||||
if existing_message is None:
|
||||
# Conflict reported but the row vanished between
|
||||
# INSERT and SELECT — extremely unlikely (would
|
||||
# require a concurrent DELETE within the same
|
||||
# transaction visibility), but preserve safe
|
||||
# behavior.
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
db_message = existing_message
|
||||
# Perf signal: counts how often the legacy FE round-trip
|
||||
# races the server-side ``finalize_assistant_turn``. A
|
||||
# rising rate after the rework is OK (it's exactly the
|
||||
# ghost-thread fix's recovery path firing); a sudden drop
|
||||
# to zero would mean the FE isn't posting appendMessage
|
||||
# at all (different bug).
|
||||
_perf_log.info(
|
||||
"[append_message] outcome=recovered_via_unique_index "
|
||||
"thread_id=%s turn_id=%s role=%s message_id=%s",
|
||||
thread_id,
|
||||
turn_id_value,
|
||||
message_role.value,
|
||||
db_message.id,
|
||||
)
|
||||
else:
|
||||
# INSERT succeeded — load the full ORM row so the
|
||||
# response can include server-side-defaulted columns
|
||||
# (``created_at``, etc.) and the relationship surface
|
||||
# stays consistent with the recovery path.
|
||||
inserted_row = await session.get(NewChatMessage, inserted_id)
|
||||
if inserted_row is None:
|
||||
# Should be impossible: we just inserted it in this
|
||||
# same transaction. Fail loud if it happens.
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Inserted message could not be loaded.",
|
||||
) from None
|
||||
db_message = inserted_row
|
||||
|
||||
# Persist token usage if provided (for assistant messages).
|
||||
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||
# forwarded by the FE through the appendMessage round-trip so
|
||||
# the historical TokenUsage row matches the credit debit applied
|
||||
# at finalize time.
|
||||
#
|
||||
# De-dup: ``finalize_assistant_turn`` may also race to write a
|
||||
# token_usage row for this same ``message_id`` (cross-session,
|
||||
# cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed
|
||||
# on the ``uq_token_usage_message_id`` partial unique index
|
||||
# (migration 142). The loser silently drops its insert; exactly
|
||||
# one row results regardless of which writer commits first.
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
await record_token_usage(
|
||||
session,
|
||||
usage_type="chat",
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
message_id=db_message.id,
|
||||
insert_stmt = (
|
||||
pg_insert(TokenUsage)
|
||||
.values(
|
||||
usage_type="chat",
|
||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
message_id=db_message.id,
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user_uuid,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["message_id"],
|
||||
index_where=sa_text("message_id IS NOT NULL"),
|
||||
)
|
||||
)
|
||||
await session.execute(insert_stmt)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
|
@ -1324,6 +1567,9 @@ async def append_message(
|
|||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
# Any IntegrityError that escaped the inline handler above
|
||||
# comes from a *different* constraint (foreign key, etc.) —
|
||||
# preserve the legacy 400 path.
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
@ -1476,6 +1722,7 @@ async def handle_new_chat(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(request.chat_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -1516,6 +1763,12 @@ async def handle_new_chat(
|
|||
else None
|
||||
)
|
||||
|
||||
mentioned_documents_payload = (
|
||||
[doc.model_dump() for doc in request.mentioned_documents]
|
||||
if request.mentioned_documents
|
||||
else None
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query,
|
||||
|
|
@ -1525,6 +1778,7 @@ async def handle_new_chat(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
|
|
@ -1550,6 +1804,93 @@ async def handle_new_chat(
|
|||
) from None
|
||||
|
||||
|
||||
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the turn currently running on ``thread_id``.

    The thread is loaded and the caller must hold ``CHATS_UPDATE`` on its
    search space and pass the thread-level access check.

    Returns:
        ``CancelActiveTurnResponse`` with ``status="idle"`` and
        ``error_code="NO_ACTIVE_TURN"`` when the turn status is idle (the
        response status code is left untouched). Otherwise the cancel is
        requested, the status code is set to 202 and a ``"cancelling"`` /
        ``"TURN_CANCELLING"`` payload is returned with any retry hints.

    Raises:
        HTTPException: 404 when no thread with ``thread_id`` exists.
    """
    thread_query = select(NewChatThread).filter(NewChatThread.id == thread_id)
    thread = (await session.execute(thread_query)).scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    # Idle threads are reported as such without issuing a cancel request.
    if _build_turn_status_payload(thread_id)["status"] == "idle":
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    request_cancel(str(thread_id))
    response.status_code = 202

    # The status payload is rebuilt after the cancel request; the retry
    # hints returned to the client come from this second payload.
    post_cancel_payload = _build_turn_status_payload(thread_id)
    retry_after_ms = int(post_cancel_payload.get("retry_after_ms") or 0)
    retry_after_at = None
    if "retry_after_at" in post_cancel_payload:
        retry_after_at = int(post_cancel_payload["retry_after_at"])

    has_retry_hint = retry_after_ms > 0
    if has_retry_hint:
        _set_retry_after_headers(response, retry_after_ms)

    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_after_ms if has_retry_hint else None,
        retry_after_at=retry_after_at,
    )
|
||||
|
||||
|
||||
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Report the turn status payload for ``thread_id``.

    The thread is loaded and the caller must hold ``CHATS_READ`` on its
    search space and pass the thread-level access check.

    Returns:
        ``TurnStatusResponse`` built from ``_build_turn_status_payload``;
        ``active_turn_id`` is always ``None`` and the retry fields are
        ``None`` when the payload does not carry them.

    Raises:
        HTTPException: 404 when no thread with ``thread_id`` exists.
    """
    thread_query = select(NewChatThread).filter(NewChatThread.id == thread_id)
    thread = (await session.execute(thread_query)).scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    payload = _build_turn_status_payload(thread_id)
    current_status = payload["status"]
    retry_after_ms = payload.get("retry_after_ms")
    retry_after_at = payload.get("retry_after_at")
    return TurnStatusResponse(
        status=current_status,  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=retry_after_ms,  # type: ignore[arg-type]
        retry_after_at=retry_after_at,  # type: ignore[arg-type]
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Regeneration Endpoint (Edit/Reload)
|
||||
# =============================================================================
|
||||
|
|
@ -1605,6 +1946,7 @@ async def regenerate_response(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -1907,6 +2249,11 @@ async def regenerate_response(
|
|||
"data": revert_results,
|
||||
}
|
||||
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
|
||||
mentioned_documents_payload = (
|
||||
[doc.model_dump() for doc in request.mentioned_documents]
|
||||
if request.mentioned_documents
|
||||
else None
|
||||
)
|
||||
try:
|
||||
async for chunk in stream_new_chat(
|
||||
user_query=str(user_query_to_use),
|
||||
|
|
@ -1916,6 +2263,7 @@ async def regenerate_response(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
|
|
@ -1924,6 +2272,7 @@ async def regenerate_response(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=regenerate_image_urls or None,
|
||||
flow="regenerate",
|
||||
):
|
||||
yield chunk
|
||||
streaming_completed = True
|
||||
|
|
@ -2011,6 +2360,7 @@ async def resume_chat(
|
|||
)
|
||||
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.schemas import (
|
|||
NewLLMConfigUpdate,
|
||||
)
|
||||
from app.services.llm_service import validate_llm_config
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -36,6 +37,39 @@ router = APIRouter()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
    """Build the read schema for a BYOK chat config, adding ``supports_image_input``.

    ``supports_image_input`` has no DB column; it is computed here via
    ``derive_supports_image_input`` from the row's provider, model name,
    optional ``litellm_params["base_model"]`` and custom provider. Routing
    every list / detail / create / update response through this helper keeps
    the response shape consistent without setting the field at each call site.
    """
    # The provider may be an enum-like object (has ``.value``) or a plain value.
    provider = config.provider
    if hasattr(provider, "value"):
        provider_value = provider.value
    else:
        provider_value = str(provider)

    # ``base_model`` is only read when ``litellm_params`` is actually a dict.
    params = config.litellm_params or {}
    base_model = None
    if isinstance(params, dict):
        base_model = params.get("base_model")

    image_capable = derive_supports_image_input(
        provider=provider_value,
        model_name=config.model_name,
        base_model=base_model,
        custom_provider=config.custom_provider,
    )

    # Validate from the ORM row first, then layer the derived flag onto a
    # copy so the validated instance itself is not mutated.
    validated = NewLLMConfigRead.model_validate(config)
    return validated.model_copy(update={"supports_image_input": image_capable})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Configs Routes
|
||||
# =============================================================================
|
||||
|
|
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": None,
|
||||
"seo_description": None,
|
||||
"quota_reserve_tokens": None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# includes at least one vision-capable deployment, so
|
||||
# treat Auto as image-capable. The router itself will
|
||||
# still pick a vision-capable deployment for messages
|
||||
# carrying image_url blocks (LiteLLM Router falls back
|
||||
# on ``404`` per its ``allowed_fails`` policy).
|
||||
"supports_image_input": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add individual global configs
|
||||
for cfg in global_configs:
|
||||
# Capability resolution: explicit value (YAML override or the OpenRouter
|
||||
# `_supports_image_input(model)` payload baked in by the
|
||||
# OpenRouter integration service) wins. Fall back to the
|
||||
# LiteLLM-driven helper which default-allows on unknown so
|
||||
# we don't hide vision-capable models that happen to lack a
|
||||
# YAML annotation. The streaming task safety net is the
|
||||
# only place a False ever blocks.
|
||||
if "supports_image_input" in cfg:
|
||||
supports_image_input = bool(cfg.get("supports_image_input"))
|
||||
else:
|
||||
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||
cfg_base_model = (
|
||||
cfg_litellm_params.get("base_model")
|
||||
if isinstance(cfg_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=cfg_base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
safe_config = {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
|
|
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": cfg.get("seo_title"),
|
||||
"seo_description": cfg.get("seo_description"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
}
|
||||
safe_configs.append(safe_config)
|
||||
|
||||
|
|
@ -171,7 +236,7 @@ async def create_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
||||
return db_config
|
||||
return _serialize_byok_config(db_config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -213,7 +278,7 @@ async def list_new_llm_configs(
|
|||
.limit(limit)
|
||||
)
|
||||
|
||||
return result.scalars().all()
|
||||
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -268,7 +333,7 @@ async def get_new_llm_config(
|
|||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -360,7 +425,7 @@ async def update_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
|
|||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
NewChatThread,
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -593,6 +594,7 @@ async def _get_image_gen_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -609,6 +611,7 @@ async def _get_image_gen_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
@ -651,6 +654,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"model_name": "auto",
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -667,6 +671,7 @@ async def _get_vision_llm_config_by_id(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
@ -790,9 +795,27 @@ async def update_llm_preferences(
|
|||
|
||||
# Update preferences
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
previous_agent_llm_id = search_space.agent_llm_id
|
||||
for key, value in update_data.items():
|
||||
setattr(search_space, key, value)
|
||||
|
||||
agent_llm_changed = (
|
||||
"agent_llm_id" in update_data
|
||||
and update_data["agent_llm_id"] != previous_agent_llm_id
|
||||
)
|
||||
if agent_llm_changed:
|
||||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(pinned_llm_config_id=None)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||
search_space_id,
|
||||
previous_agent_llm_id,
|
||||
update_data["agent_llm_id"],
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
|
||||
|
|
|
|||
|
|
@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
|
|||
metadata = _get_metadata(checkout_session)
|
||||
user_id = metadata.get("user_id")
|
||||
quantity = int(metadata.get("quantity", "0"))
|
||||
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
|
||||
# Read the new metadata key first, fall back to the legacy one so
|
||||
# in-flight checkout sessions created before the cost-credits
|
||||
# release still fulfil correctly (the unit is numerically the
|
||||
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
|
||||
credit_micros_per_unit = int(
|
||||
metadata.get("credit_micros_per_unit")
|
||||
or metadata.get("tokens_per_unit", "0")
|
||||
)
|
||||
|
||||
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
|
||||
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
|
||||
logger.error(
|
||||
"Skipping token fulfillment for session %s: incomplete metadata %s",
|
||||
checkout_session_id,
|
||||
|
|
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=quantity,
|
||||
tokens_granted=quantity * tokens_per_unit,
|
||||
credit_micros_granted=quantity * credit_micros_per_unit,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
|
|||
purchase.stripe_payment_intent_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "payment_intent", None)
|
||||
)
|
||||
user.premium_tokens_limit = (
|
||||
max(user.premium_tokens_used, user.premium_tokens_limit)
|
||||
+ purchase.tokens_granted
|
||||
# Top up the user's credit balance by the granted micro-USD amount.
|
||||
# ``max(used, limit)`` clamps the case where the legacy code wrote a
|
||||
# used value above the limit (e.g. underbilling rounding) so adding
|
||||
# ``credit_micros_granted`` always lifts the limit by the full pack
|
||||
# size rather than disappearing into past overuse.
|
||||
user.premium_credit_micros_limit = (
|
||||
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
|
||||
+ purchase.credit_micros_granted
|
||||
)
|
||||
|
||||
await db_session.commit()
|
||||
|
|
@ -532,12 +544,18 @@ async def create_token_checkout_session(
|
|||
user: User = Depends(current_active_user),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Create a Stripe Checkout Session for buying premium token packs."""
|
||||
"""Create a Stripe Checkout Session for buying premium credit packs.
|
||||
|
||||
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
|
||||
credit (default 1_000_000 = $1.00). The user's balance is debited
|
||||
at the actual provider cost reported by LiteLLM at finalize time,
|
||||
so $1 of credit always buys $1 worth of provider usage at cost.
|
||||
"""
|
||||
_ensure_token_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_required_token_price_id()
|
||||
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
|
||||
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
|
||||
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
|
||||
|
||||
try:
|
||||
checkout_session = stripe_client.v1.checkout.sessions.create(
|
||||
|
|
@ -556,8 +574,8 @@ async def create_token_checkout_session(
|
|||
"metadata": {
|
||||
"user_id": str(user.id),
|
||||
"quantity": str(body.quantity),
|
||||
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
|
||||
"purchase_type": "premium_tokens",
|
||||
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
|
||||
"purchase_type": "premium_credit",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -583,7 +601,7 @@ async def create_token_checkout_session(
|
|||
getattr(checkout_session, "payment_intent", None)
|
||||
),
|
||||
quantity=body.quantity,
|
||||
tokens_granted=tokens_granted,
|
||||
credit_micros_granted=credit_micros_granted,
|
||||
amount_total=getattr(checkout_session, "amount_total", None),
|
||||
currency=getattr(checkout_session, "currency", None),
|
||||
status=PremiumTokenPurchaseStatus.PENDING,
|
||||
|
|
@ -598,14 +616,19 @@ async def create_token_checkout_session(
|
|||
async def get_token_status(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return token-buying availability and current premium quota for frontend."""
|
||||
used = user.premium_tokens_used
|
||||
limit = user.premium_tokens_limit
|
||||
"""Return token-buying availability and current premium credit quota for frontend.
|
||||
|
||||
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
|
||||
when displaying. The route name is preserved for back-compat with
|
||||
pinned client deployments.
|
||||
"""
|
||||
used = user.premium_credit_micros_used
|
||||
limit = user.premium_credit_micros_limit
|
||||
return TokenStripeStatusResponse(
|
||||
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
|
||||
premium_tokens_used=used,
|
||||
premium_tokens_limit=limit,
|
||||
premium_tokens_remaining=max(0, limit - used),
|
||||
premium_credit_micros_used=used,
|
||||
premium_credit_micros_limit=limit,
|
||||
premium_credit_micros_remaining=max(0, limit - used),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -82,10 +82,15 @@ async def get_global_vision_llm_configs(
|
|||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -98,6 +103,14 @@ async def get_global_vision_llm_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue