mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
fix(chat): enforce pinned model quota flow and reset stale pins
This commit is contained in:
parent
41849fe10f
commit
835bd9f65d
2 changed files with 88 additions and 44 deletions
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
|
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ImageGenerationConfig,
|
ImageGenerationConfig,
|
||||||
|
NewChatThread,
|
||||||
NewLLMConfig,
|
NewLLMConfig,
|
||||||
Permission,
|
Permission,
|
||||||
SearchSpace,
|
SearchSpace,
|
||||||
|
|
@ -790,9 +791,31 @@ async def update_llm_preferences(
|
||||||
|
|
||||||
# Update preferences
|
# Update preferences
|
||||||
update_data = preferences.model_dump(exclude_unset=True)
|
update_data = preferences.model_dump(exclude_unset=True)
|
||||||
|
previous_agent_llm_id = search_space.agent_llm_id
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(search_space, key, value)
|
setattr(search_space, key, value)
|
||||||
|
|
||||||
|
agent_llm_changed = (
|
||||||
|
"agent_llm_id" in update_data
|
||||||
|
and update_data["agent_llm_id"] != previous_agent_llm_id
|
||||||
|
)
|
||||||
|
if agent_llm_changed:
|
||||||
|
await session.execute(
|
||||||
|
update(NewChatThread)
|
||||||
|
.where(NewChatThread.search_space_id == search_space_id)
|
||||||
|
.values(
|
||||||
|
pinned_llm_config_id=None,
|
||||||
|
pinned_auto_mode=None,
|
||||||
|
pinned_at=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||||
|
search_space_id,
|
||||||
|
previous_agent_llm_id,
|
||||||
|
update_data["agent_llm_id"],
|
||||||
|
)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(search_space)
|
await session.refresh(search_space)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,7 @@ from app.db import (
|
||||||
shielded_async_session,
|
shielded_async_session,
|
||||||
)
|
)
|
||||||
from app.prompts import TITLE_GENERATION_PROMPT
|
from app.prompts import TITLE_GENERATION_PROMPT
|
||||||
|
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||||
from app.services.chat_session_state_service import (
|
from app.services.chat_session_state_service import (
|
||||||
clear_ai_responding,
|
clear_ai_responding,
|
||||||
set_ai_responding,
|
set_ai_responding,
|
||||||
|
|
@ -1456,6 +1457,21 @@ async def stream_new_chat(
|
||||||
agent_config: AgentConfig | None = None
|
agent_config: AgentConfig | None = None
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
llm_config_id = (
|
||||||
|
await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
selected_llm_config_id=llm_config_id,
|
||||||
|
)
|
||||||
|
).resolved_llm_config_id
|
||||||
|
except ValueError as pin_error:
|
||||||
|
yield streaming_service.format_error(str(pin_error))
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
if llm_config_id >= 0:
|
if llm_config_id >= 0:
|
||||||
# Positive ID: Load from NewLLMConfig database table
|
# Positive ID: Load from NewLLMConfig database table
|
||||||
agent_config = await load_agent_config(
|
agent_config = await load_agent_config(
|
||||||
|
|
@ -1491,12 +1507,11 @@ async def stream_new_chat(
|
||||||
llm_config_id,
|
llm_config_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Premium quota reservation — applies to explicitly premium configs
|
# Premium quota reservation for pinned premium model only.
|
||||||
# AND Auto mode (which may route to premium models).
|
|
||||||
_needs_premium_quota = (
|
_needs_premium_quota = (
|
||||||
agent_config is not None
|
agent_config is not None
|
||||||
and user_id
|
and user_id
|
||||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
and agent_config.is_premium
|
||||||
)
|
)
|
||||||
if _needs_premium_quota:
|
if _needs_premium_quota:
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
|
|
@ -1519,16 +1534,18 @@ async def stream_new_chat(
|
||||||
)
|
)
|
||||||
_premium_reserved = reserve_amount
|
_premium_reserved = reserve_amount
|
||||||
if not quota_result.allowed:
|
if not quota_result.allowed:
|
||||||
if agent_config.is_premium:
|
logging.getLogger(__name__).info(
|
||||||
yield streaming_service.format_error(
|
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
|
||||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
chat_id,
|
||||||
)
|
search_space_id,
|
||||||
yield streaming_service.format_done()
|
user_id,
|
||||||
return
|
llm_config_id,
|
||||||
# Auto mode: quota exhausted but we can still proceed
|
)
|
||||||
# (the router may pick a free model). Reset reservation.
|
yield streaming_service.format_error(
|
||||||
_premium_request_id = None
|
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
|
||||||
_premium_reserved = 0
|
)
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
if not llm:
|
if not llm:
|
||||||
yield streaming_service.format_error("Failed to create LLM instance")
|
yield streaming_service.format_error("Failed to create LLM instance")
|
||||||
|
|
@ -1961,28 +1978,20 @@ async def stream_new_chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Finalize premium quota with actual tokens.
|
# Finalize premium quota with actual tokens.
|
||||||
# For Auto mode, only count tokens from calls that used premium models.
|
|
||||||
if _premium_request_id and user_id:
|
if _premium_request_id and user_id:
|
||||||
try:
|
try:
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
||||||
if agent_config and agent_config.is_auto_mode:
|
|
||||||
from app.services.llm_router_service import LLMRouterService
|
|
||||||
|
|
||||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
|
||||||
accumulator.calls
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
actual_premium_tokens = accumulator.grand_total
|
|
||||||
|
|
||||||
async with shielded_async_session() as quota_session:
|
async with shielded_async_session() as quota_session:
|
||||||
await TokenQuotaService.premium_finalize(
|
await TokenQuotaService.premium_finalize(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
request_id=_premium_request_id,
|
request_id=_premium_request_id,
|
||||||
actual_tokens=actual_premium_tokens,
|
actual_tokens=accumulator.grand_total,
|
||||||
reserved_tokens=_premium_reserved,
|
reserved_tokens=_premium_reserved,
|
||||||
)
|
)
|
||||||
|
_premium_request_id = None
|
||||||
|
_premium_reserved = 0
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
logging.getLogger(__name__).warning(
|
||||||
"Failed to finalize premium quota for user %s",
|
"Failed to finalize premium quota for user %s",
|
||||||
|
|
@ -2175,6 +2184,21 @@ async def stream_resume_chat(
|
||||||
|
|
||||||
agent_config: AgentConfig | None = None
|
agent_config: AgentConfig | None = None
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
llm_config_id = (
|
||||||
|
await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
selected_llm_config_id=llm_config_id,
|
||||||
|
)
|
||||||
|
).resolved_llm_config_id
|
||||||
|
except ValueError as pin_error:
|
||||||
|
yield streaming_service.format_error(str(pin_error))
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
if llm_config_id >= 0:
|
if llm_config_id >= 0:
|
||||||
agent_config = await load_agent_config(
|
agent_config = await load_agent_config(
|
||||||
session=session,
|
session=session,
|
||||||
|
|
@ -2208,7 +2232,7 @@ async def stream_resume_chat(
|
||||||
_resume_needs_premium = (
|
_resume_needs_premium = (
|
||||||
agent_config is not None
|
agent_config is not None
|
||||||
and user_id
|
and user_id
|
||||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
and agent_config.is_premium
|
||||||
)
|
)
|
||||||
if _resume_needs_premium:
|
if _resume_needs_premium:
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
|
|
@ -2231,14 +2255,18 @@ async def stream_resume_chat(
|
||||||
)
|
)
|
||||||
_resume_premium_reserved = reserve_amount
|
_resume_premium_reserved = reserve_amount
|
||||||
if not quota_result.allowed:
|
if not quota_result.allowed:
|
||||||
if agent_config.is_premium:
|
logging.getLogger(__name__).info(
|
||||||
yield streaming_service.format_error(
|
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
|
||||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
chat_id,
|
||||||
)
|
search_space_id,
|
||||||
yield streaming_service.format_done()
|
user_id,
|
||||||
return
|
llm_config_id,
|
||||||
_resume_premium_request_id = None
|
)
|
||||||
_resume_premium_reserved = 0
|
yield streaming_service.format_error(
|
||||||
|
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
|
||||||
|
)
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
if not llm:
|
if not llm:
|
||||||
yield streaming_service.format_error("Failed to create LLM instance")
|
yield streaming_service.format_error("Failed to create LLM instance")
|
||||||
|
|
@ -2370,23 +2398,16 @@ async def stream_resume_chat(
|
||||||
try:
|
try:
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
||||||
if agent_config and agent_config.is_auto_mode:
|
|
||||||
from app.services.llm_router_service import LLMRouterService
|
|
||||||
|
|
||||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
|
||||||
accumulator.calls
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
actual_premium_tokens = accumulator.grand_total
|
|
||||||
|
|
||||||
async with shielded_async_session() as quota_session:
|
async with shielded_async_session() as quota_session:
|
||||||
await TokenQuotaService.premium_finalize(
|
await TokenQuotaService.premium_finalize(
|
||||||
db_session=quota_session,
|
db_session=quota_session,
|
||||||
user_id=UUID(user_id),
|
user_id=UUID(user_id),
|
||||||
request_id=_resume_premium_request_id,
|
request_id=_resume_premium_request_id,
|
||||||
actual_tokens=actual_premium_tokens,
|
actual_tokens=accumulator.grand_total,
|
||||||
reserved_tokens=_resume_premium_reserved,
|
reserved_tokens=_resume_premium_reserved,
|
||||||
)
|
)
|
||||||
|
_resume_premium_request_id = None
|
||||||
|
_resume_premium_reserved = 0
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
logging.getLogger(__name__).warning(
|
||||||
"Failed to finalize premium quota for user %s (resume)",
|
"Failed to finalize premium quota for user %s (resume)",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue