Merge remote-tracking branch 'upstream/dev' into fix/memory-extraction

This commit is contained in:
Anish Sarkar 2026-05-04 12:03:44 +05:30
commit b981b51ab1
176 changed files with 20407 additions and 6258 deletions

View file

@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.config import config
from app.db import User
from app.users import current_active_user
@ -58,10 +59,15 @@ class AgentFeatureFlagsRead(BaseModel):
enable_otel: bool
enable_desktop_local_filesystem: bool
@classmethod
def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead:
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
return cls(**asdict(flags))
return cls(
**asdict(flags),
enable_desktop_local_filesystem=config.ENABLE_DESKTOP_LOCAL_FILESYSTEM,
)
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)

View file

@ -649,13 +649,9 @@ async def list_composio_drive_folders(
"""
List folders AND files in user's Google Drive via Composio.
Uses the same GoogleDriveClient / list_folder_contents path as the native
connector, with Composio-sourced credentials. This means auth errors
propagate identically (Google returns 401 → exception → auth_expired flag).
Uses Composio's Google Drive tool execution path so managed OAuth tokens
do not need to be exposed through connected account state.
"""
from app.connectors.google_drive import GoogleDriveClient, list_folder_contents
from app.utils.google_credentials import build_composio_credentials
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503,
@ -689,10 +685,37 @@ async def list_composio_drive_folders(
detail="Composio connected account not found. Please reconnect the connector.",
)
credentials = build_composio_credentials(composio_connected_account_id)
drive_client = GoogleDriveClient(session, connector_id, credentials=credentials)
service = ComposioService()
entity_id = f"surfsense_{user.id}"
items = []
page_token = None
error = None
items, error = await list_folder_contents(drive_client, parent_id=parent_id)
while True:
page_items, next_token, page_error = await service.get_drive_files(
connected_account_id=composio_connected_account_id,
entity_id=entity_id,
folder_id=parent_id,
page_token=page_token,
page_size=100,
)
if page_error:
error = page_error
break
items.extend(page_items)
if not next_token:
break
page_token = next_token
for item in items:
item["isFolder"] = (
item.get("mimeType") == "application/vnd.google-apps.folder"
)
items.sort(
key=lambda item: (not item["isFolder"], item.get("name", "").lower())
)
if error:
error_lower = error.lower()

View file

@ -36,11 +36,17 @@ from app.schemas import (
ImageGenerationListRead,
ImageGenerationRead,
)
from app.services.billable_calls import (
DEFAULT_IMAGE_RESERVE_MICROS,
QuotaInsufficientError,
billable_call,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -82,14 +88,62 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
return None
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
"""Resolve the LiteLLM provider prefix used in model strings."""
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
"""Build a litellm model string from provider + model_name."""
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}"
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
async def _resolve_billing_for_image_gen(
session: AsyncSession,
config_id: int | None,
search_space: SearchSpace,
) -> tuple[str, str, int]:
"""Resolve ``(billing_tier, base_model, reserve_micros)`` for a request.
The resolution mirrors ``_execute_image_generation``'s lookup tree but
only extracts the fields needed for billing we do this *before*
``billable_call`` so the reservation is correctly sized for the
config that will actually run, and so we don't open an
``ImageGeneration`` row for a request that's about to 402.
User-owned (positive ID) BYOK configs are always free they cost
the user nothing on our side. Auto mode currently treats as free
because the underlying router can dispatch to either premium or
free YAML configs and we don't surface the resolved deployment up
here yet. Bringing Auto under premium billing would require
threading the chosen deployment back from ``ImageGenRouterService``.
"""
resolved_id = config_id
if resolved_id is None:
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
if is_image_gen_auto_mode(resolved_id):
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
if resolved_id < 0:
cfg = _get_global_image_gen_config(resolved_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
base_model = _build_model_string(
cfg.get("provider", ""),
cfg.get("model_name", ""),
cfg.get("custom_provider"),
)
reserve_micros = int(
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
)
return (billing_tier, base_model, reserve_micros)
# Positive ID = user-owned BYOK image-gen config — always free.
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
async def _execute_image_generation(
@ -138,12 +192,18 @@ async def _execute_image_generation(
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
model_string = _build_model_string(
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"):
gen_kwargs["api_base"] = cfg["api_base"]
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@ -165,12 +225,18 @@ async def _execute_image_generation(
if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found")
model_string = _build_model_string(
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base:
gen_kwargs["api_base"] = db_cfg.api_base
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
@ -225,10 +291,15 @@ async def get_global_image_gen_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode currently treated as free until per-deployment
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@ -241,6 +312,12 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)
@ -454,7 +531,26 @@ async def create_image_generation(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Create and execute an image generation request."""
"""Create and execute an image generation request.
Premium configs are gated by the user's shared premium credit pool.
The flow is:
1. Permission check + load the search space (cheap, no provider call).
2. Resolve which config will run so we know its billing tier and the
worst-case reservation size *before* opening any DB rows.
3. Wrap the entire ImageGeneration row insert + provider call in
``billable_call``. If quota is denied, ``billable_call`` raises
``QuotaInsufficientError`` *before* we flush a row, which we
translate to HTTP 402 (no orphaned rows on the user's account,
no inserted error rows for "you ran out of credit").
4. On success, the actual ``response_cost`` flows through the
LiteLLM callback into the accumulator, and ``billable_call``
finalizes the debit at exit. Inner ``try/except`` still catches
provider errors and stores them on ``error_message`` (HTTP 200
with ``error_message`` set is preserved for failed-but-not-quota
scenarios — clients already know how to surface those).
"""
try:
await check_permission(
session,
@ -471,33 +567,70 @@ async def create_image_generation(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
db_image_gen = ImageGeneration(
prompt=data.prompt,
model=data.model,
n=data.n,
quality=data.quality,
size=data.size,
style=data.style,
response_format=data.response_format,
image_generation_config_id=data.image_generation_config_id,
search_space_id=data.search_space_id,
created_by_id=user.id,
billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
session, data.image_generation_config_id, search_space
)
session.add(db_image_gen)
await session.flush()
try:
await _execute_image_generation(session, db_image_gen, search_space)
except Exception as e:
logger.exception("Image generation call failed")
db_image_gen.error_message = str(e)
# billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
# propagates to the outer ``except QuotaInsufficientError`` handler
# below as HTTP 402 — it is intentionally NOT swallowed into
# ``error_message`` because that would (1) imply a successful row
# exists when none does, and (2) return HTTP 200 to a client
# whose request was actively *denied* (issue K).
async with billable_call(
user_id=search_space.user_id,
search_space_id=data.search_space_id,
billing_tier=billing_tier,
base_model=base_model,
quota_reserve_micros_override=reserve_micros,
usage_type="image_generation",
call_details={"model": base_model, "prompt": data.prompt[:100]},
):
db_image_gen = ImageGeneration(
prompt=data.prompt,
model=data.model,
n=data.n,
quality=data.quality,
size=data.size,
style=data.style,
response_format=data.response_format,
image_generation_config_id=data.image_generation_config_id,
search_space_id=data.search_space_id,
created_by_id=user.id,
)
session.add(db_image_gen)
await session.flush()
await session.commit()
await session.refresh(db_image_gen)
return db_image_gen
try:
await _execute_image_generation(session, db_image_gen, search_space)
except Exception as e:
logger.exception("Image generation call failed")
db_image_gen.error_message = str(e)
await session.commit()
await session.refresh(db_image_gen)
return db_image_gen
except HTTPException:
raise
except QuotaInsufficientError as exc:
# The user's premium credit pool is empty. No DB row is created
# because ``billable_call`` denies before yielding (issue K).
await session.rollback()
raise HTTPException(
status_code=402,
detail={
"error_code": "premium_quota_exhausted",
"usage_type": exc.usage_type,
"used_micros": exc.used_micros,
"limit_micros": exc.limit_micros,
"remaining_micros": exc.remaining_micros,
"message": (
"Out of premium credits for image generation. "
"Purchase additional credits or switch to a free model."
),
},
) from exc
except SQLAlchemyError:
await session.rollback()
raise HTTPException(

View file

@ -1366,7 +1366,11 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
# Persist token usage if provided (for assistant messages)
# Persist token usage if provided (for assistant messages).
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
# forwarded by the FE through the appendMessage round-trip so
# the historical TokenUsage row matches the credit debit applied
# at finalize time.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
await record_token_usage(
@ -1377,6 +1381,7 @@ async def append_message(
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,

View file

@ -29,6 +29,7 @@ from app.schemas import (
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
from app.services.provider_capabilities import derive_supports_image_input
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -36,6 +37,39 @@ router = APIRouter()
logger = logging.getLogger(__name__)
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
There is no DB column for ``supports_image_input`` — the value is
resolved at the API boundary from LiteLLM's authoritative model map
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
the response shape consistent across list / detail / create / update
endpoints without having to remember to set the field at every call
site.
"""
provider_value = (
config.provider.value
if hasattr(config.provider, "value")
else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
supports_image_input = derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
)
# ``model_validate`` runs the Pydantic conversion using the ORM
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
# then we layer the derived field on. ``model_copy(update=...)`` keeps
# the surface immutable from the caller's perspective.
base_read = NewLLMConfigRead.model_validate(config)
return base_read.model_copy(update={"supports_image_input": supports_image_input})
# =============================================================================
# Global Configs Routes
# =============================================================================
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
"seo_title": None,
"seo_description": None,
"quota_reserve_tokens": None,
# Auto routes across the configured pool, which usually
# includes at least one vision-capable deployment, so
# treat Auto as image-capable. The router itself will
# still pick a vision-capable deployment for messages
# carrying image_url blocks (LiteLLM Router falls back
# on ``404`` per its ``allowed_fails`` policy).
"supports_image_input": True,
}
)
# Add individual global configs
for cfg in global_configs:
# Capability resolution: explicit value (YAML override OR the
# `_supports_image_input(model)` payload baked in by the
# OpenRouter integration service) wins. Fall back to the
# LiteLLM-driven helper which default-allows on unknown so
# we don't hide vision-capable models that happen to lack a
# YAML annotation. The streaming task safety net is the
# only place a False ever blocks.
if "supports_image_input" in cfg:
supports_image_input = bool(cfg.get("supports_image_input"))
else:
cfg_litellm_params = cfg.get("litellm_params") or {}
cfg_base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
supports_image_input = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=cfg_base_model,
custom_provider=cfg.get("custom_provider"),
)
safe_config = {
"id": cfg.get("id"),
"name": cfg.get("name"),
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
"seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"supports_image_input": supports_image_input,
}
safe_configs.append(safe_config)
@ -171,7 +236,7 @@ async def create_new_llm_config(
await session.commit()
await session.refresh(db_config)
return db_config
return _serialize_byok_config(db_config)
except HTTPException:
raise
@ -213,7 +278,7 @@ async def list_new_llm_configs(
.limit(limit)
)
return result.scalars().all()
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
except HTTPException:
raise
@ -268,7 +333,7 @@ async def get_new_llm_config(
"You don't have permission to view LLM configurations in this search space",
)
return config
return _serialize_byok_config(config)
except HTTPException:
raise
@ -360,7 +425,7 @@ async def update_new_llm_config(
await session.commit()
await session.refresh(config)
return config
return _serialize_byok_config(config)
except HTTPException:
raise

View file

@ -591,6 +591,7 @@ async def _get_image_gen_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
@ -607,6 +608,7 @@ async def _get_image_gen_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None
@ -649,6 +651,7 @@ async def _get_vision_llm_config_by_id(
"model_name": "auto",
"is_global": True,
"is_auto_mode": True,
"billing_tier": "free",
}
if config_id < 0:
@ -665,6 +668,7 @@ async def _get_vision_llm_config_by_id(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
}
return None

View file

@ -251,9 +251,16 @@ async def _fulfill_completed_token_purchase(
metadata = _get_metadata(checkout_session)
user_id = metadata.get("user_id")
quantity = int(metadata.get("quantity", "0"))
tokens_per_unit = int(metadata.get("tokens_per_unit", "0"))
# Read the new metadata key first, fall back to the legacy one so
# in-flight checkout sessions created before the cost-credits
# release still fulfil correctly (the unit is numerically the
# same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens).
credit_micros_per_unit = int(
metadata.get("credit_micros_per_unit")
or metadata.get("tokens_per_unit", "0")
)
if not user_id or quantity <= 0 or tokens_per_unit <= 0:
if not user_id or quantity <= 0 or credit_micros_per_unit <= 0:
logger.error(
"Skipping token fulfillment for session %s: incomplete metadata %s",
checkout_session_id,
@ -268,7 +275,7 @@ async def _fulfill_completed_token_purchase(
getattr(checkout_session, "payment_intent", None)
),
quantity=quantity,
tokens_granted=quantity * tokens_per_unit,
credit_micros_granted=quantity * credit_micros_per_unit,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@ -303,9 +310,14 @@ async def _fulfill_completed_token_purchase(
purchase.stripe_payment_intent_id = _normalize_optional_string(
getattr(checkout_session, "payment_intent", None)
)
user.premium_tokens_limit = (
max(user.premium_tokens_used, user.premium_tokens_limit)
+ purchase.tokens_granted
# Top up the user's credit balance by the granted micro-USD amount.
# ``max(used, limit)`` clamps the case where the legacy code wrote a
# used value above the limit (e.g. underbilling rounding) so adding
# ``credit_micros_granted`` always lifts the limit by the full pack
# size rather than disappearing into past overuse.
user.premium_credit_micros_limit = (
max(user.premium_credit_micros_used, user.premium_credit_micros_limit)
+ purchase.credit_micros_granted
)
await db_session.commit()
@ -532,12 +544,18 @@ async def create_token_checkout_session(
user: User = Depends(current_active_user),
db_session: AsyncSession = Depends(get_async_session),
):
"""Create a Stripe Checkout Session for buying premium token packs."""
"""Create a Stripe Checkout Session for buying premium credit packs.
Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of
credit (default 1_000_000 = $1.00). The user's balance is debited
at the actual provider cost reported by LiteLLM at finalize time,
so $1 of credit always buys $1 worth of provider usage at cost.
"""
_ensure_token_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_token_price_id()
success_url, cancel_url = _get_token_checkout_urls(body.search_space_id)
tokens_granted = body.quantity * config.STRIPE_TOKENS_PER_UNIT
credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT
try:
checkout_session = stripe_client.v1.checkout.sessions.create(
@ -556,8 +574,8 @@ async def create_token_checkout_session(
"metadata": {
"user_id": str(user.id),
"quantity": str(body.quantity),
"tokens_per_unit": str(config.STRIPE_TOKENS_PER_UNIT),
"purchase_type": "premium_tokens",
"credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT),
"purchase_type": "premium_credit",
},
}
)
@ -583,7 +601,7 @@ async def create_token_checkout_session(
getattr(checkout_session, "payment_intent", None)
),
quantity=body.quantity,
tokens_granted=tokens_granted,
credit_micros_granted=credit_micros_granted,
amount_total=getattr(checkout_session, "amount_total", None),
currency=getattr(checkout_session, "currency", None),
status=PremiumTokenPurchaseStatus.PENDING,
@ -598,14 +616,19 @@ async def create_token_checkout_session(
async def get_token_status(
user: User = Depends(current_active_user),
):
"""Return token-buying availability and current premium quota for frontend."""
used = user.premium_tokens_used
limit = user.premium_tokens_limit
"""Return token-buying availability and current premium credit quota for frontend.
Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M
when displaying. The route name is preserved for back-compat with
pinned client deployments.
"""
used = user.premium_credit_micros_used
limit = user.premium_credit_micros_limit
return TokenStripeStatusResponse(
token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED,
premium_tokens_used=used,
premium_tokens_limit=limit,
premium_tokens_remaining=max(0, limit - used),
premium_credit_micros_used=used,
premium_credit_micros_limit=limit,
premium_credit_micros_remaining=max(0, limit - used),
)

View file

@ -82,10 +82,15 @@ async def get_global_vision_llm_configs(
"litellm_params": {},
"is_global": True,
"is_auto_mode": True,
# Auto mode treated as free until per-deployment billing-tier
# surfacing lands; see ``get_vision_llm`` for parity.
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@ -98,6 +103,14 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"input_cost_per_token": cfg.get("input_cost_per_token"),
"output_cost_per_token": cfg.get("output_cost_per_token"),
}
)