mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
feat(chat): implement forced repin to free tier for pinned LLM configurations
This commit is contained in:
parent
c598d7038f
commit
d66fa1559b
3 changed files with 200 additions and 64 deletions
|
|
@ -84,6 +84,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
search_space_id: int,
|
||||
user_id: str | UUID | None,
|
||||
selected_llm_config_id: int,
|
||||
force_repin_free: bool = False,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
|
||||
|
||||
|
|
@ -130,9 +131,12 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse existing valid pin without re-checking current quota (no silent tier switch).
|
||||
# Reuse existing valid pin without re-checking current quota (no silent tier switch),
|
||||
# unless the caller explicitly requests a forced repin to free.
|
||||
pinned_id = thread.pinned_llm_config_id
|
||||
if (
|
||||
not force_repin_free
|
||||
and
|
||||
thread.pinned_auto_mode == AUTO_FASTEST_MODE
|
||||
and pinned_id is not None
|
||||
and int(pinned_id) in candidate_by_id
|
||||
|
|
@ -159,7 +163,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
thread.pinned_auto_mode,
|
||||
)
|
||||
|
||||
premium_eligible = await _is_premium_eligible(session, user_id)
|
||||
premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id)
|
||||
if premium_eligible:
|
||||
eligible = candidates
|
||||
else:
|
||||
|
|
@ -179,6 +183,15 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
thread.pinned_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
if force_repin_free:
|
||||
logger.info(
|
||||
"auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
selected_id,
|
||||
)
|
||||
|
||||
if pinned_id is None:
|
||||
logger.info(
|
||||
"auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
|
||||
|
|
|
|||
|
|
@ -1455,6 +1455,37 @@ async def stream_new_chat(
|
|||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
|
||||
agent_config: AgentConfig | None = None
|
||||
requested_llm_config_id = llm_config_id
|
||||
|
||||
async def _load_llm_bundle(
    config_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
    """Build the chat model and agent config for ``config_id``.

    Returns a ``(llm, agent_config, error)`` triple. On success ``error``
    is ``None``. When the config cannot be loaded, ``llm`` and
    ``agent_config`` are both ``None`` and ``error`` carries the message
    the caller should report.
    """
    if config_id >= 0:
        # Non-negative id: resolved through load_agent_config using the
        # enclosing function's session and search space.
        db_agent_config = await load_agent_config(
            session=session,
            config_id=config_id,
            search_space_id=search_space_id,
        )
        if db_agent_config:
            chat_model = create_chat_litellm_from_agent_config(db_agent_config)
            return chat_model, db_agent_config, None
        return None, None, f"Failed to load NewLLMConfig with id {config_id}"

    # Negative id: resolved through the global config lookup.
    global_config = load_global_llm_config_by_id(config_id)
    if global_config:
        chat_model = create_chat_litellm_from_config(global_config)
        yaml_agent_config = AgentConfig.from_yaml_config(global_config)
        return chat_model, yaml_agent_config, None
    return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
|
|
@ -1472,35 +1503,11 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if llm_config_id >= 0:
|
||||
# Positive ID: Load from NewLLMConfig database table
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=llm_config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if not agent_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load NewLLMConfig with id {llm_config_id}"
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Create ChatLiteLLM from AgentConfig
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
else:
|
||||
# Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models)
|
||||
llm_config = load_global_llm_config_by_id(llm_config_id)
|
||||
if not llm_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load LLM config with id {llm_config_id}"
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Create ChatLiteLLM from global config dict
|
||||
llm = create_chat_litellm_from_config(llm_config)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield streaming_service.format_error(llm_load_error)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -1541,11 +1548,43 @@ async def stream_new_chat(
|
|||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
"Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
force_repin_free=True,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield streaming_service.format_error(str(pin_error))
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield streaming_service.format_error(llm_load_error)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
logging.getLogger(__name__).info(
|
||||
"premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s",
|
||||
chat_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_error(
|
||||
"Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -2183,6 +2222,38 @@ async def stream_resume_chat(
|
|||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
|
||||
agent_config: AgentConfig | None = None
|
||||
requested_llm_config_id = llm_config_id
|
||||
|
||||
async def _load_llm_bundle(
    config_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
    """Build the chat model and agent config for ``config_id``.

    Returns a ``(llm, agent_config, error)`` triple. On success ``error``
    is ``None``. When the config cannot be loaded, ``llm`` and
    ``agent_config`` are both ``None`` and ``error`` carries the message
    the caller should report.
    """
    if config_id >= 0:
        # Non-negative id: resolved through load_agent_config using the
        # enclosing function's session and search space.
        db_agent_config = await load_agent_config(
            session=session,
            config_id=config_id,
            search_space_id=search_space_id,
        )
        if db_agent_config:
            chat_model = create_chat_litellm_from_agent_config(db_agent_config)
            return chat_model, db_agent_config, None
        return None, None, f"Failed to load NewLLMConfig with id {config_id}"

    # Negative id: resolved through the global config lookup.
    global_config = load_global_llm_config_by_id(config_id)
    if global_config:
        chat_model = create_chat_litellm_from_config(global_config)
        yaml_agent_config = AgentConfig.from_yaml_config(global_config)
        return chat_model, yaml_agent_config, None
    return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
llm_config_id = (
|
||||
|
|
@ -2199,29 +2270,11 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if llm_config_id >= 0:
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=llm_config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if not agent_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load NewLLMConfig with id {llm_config_id}"
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm = create_chat_litellm_from_agent_config(agent_config)
|
||||
else:
|
||||
llm_config = load_global_llm_config_by_id(llm_config_id)
|
||||
if not llm_config:
|
||||
yield streaming_service.format_error(
|
||||
f"Failed to load LLM config with id {llm_config_id}"
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm = create_chat_litellm_from_config(llm_config)
|
||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield streaming_service.format_error(llm_load_error)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_perf_log.info(
|
||||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
|
@ -2262,11 +2315,43 @@ async def stream_resume_chat(
|
|||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
"Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
force_repin_free=True,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield streaming_service.format_error(str(pin_error))
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
yield streaming_service.format_error(llm_load_error)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
logging.getLogger(__name__).info(
|
||||
"premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s",
|
||||
chat_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
else:
|
||||
yield streaming_service.format_error(
|
||||
"Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
|
|||
|
|
@ -227,6 +227,44 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
|||
assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
    """A forced free repin replaces an existing Auto pin on a premium config.

    The thread starts pinned to the premium config (-1) in Auto mode. With
    ``force_repin_free=True`` the resolver must return the free config (-2),
    report it as a new pin, and persist it on the thread.
    """
    from app.config import config

    session = _FakeSession(
        _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
    )

    free_config = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
        "billing_tier": "free",
    }
    premium_config = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_config, premium_config])

    async def _quota_denied(*_args, **_kwargs):
        # Stand-in for the premium usage lookup: always reports "not allowed".
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
        force_repin_free=True,
    )

    assert resolution.resolved_llm_config_id == -2
    assert resolution.resolved_tier == "free"
    assert resolution.from_existing_pin is False
    assert session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue