mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 05:12:38 +02:00
feat: unified credits and its cost calculations
This commit is contained in:
parent
451a98936e
commit
ae9d36d77f
61 changed files with 5835 additions and 272 deletions
|
|
@ -9,7 +9,13 @@ from sqlalchemy import select
|
|||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config as app_config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.services.billable_calls import (
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -96,6 +102,31 @@ async def _generate_content_podcast(
|
|||
podcast.status = PodcastStatus.GENERATING
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
(
|
||||
owner_user_id,
|
||||
billing_tier,
|
||||
base_model,
|
||||
) = await _resolve_agent_billing_for_search_space(
|
||||
session,
|
||||
search_space_id,
|
||||
thread_id=podcast.thread_id,
|
||||
)
|
||||
except ValueError as resolve_err:
|
||||
logger.error(
|
||||
"Podcast %s: cannot resolve billing for search_space=%s: %s",
|
||||
podcast.id,
|
||||
search_space_id,
|
||||
resolve_err,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "billing_resolution_failed",
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"podcast_title": podcast.title,
|
||||
|
|
@ -109,9 +140,39 @@ async def _generate_content_podcast(
|
|||
db_session=session,
|
||||
)
|
||||
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
try:
|
||||
async with billable_call(
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
||||
usage_type="podcast_generation",
|
||||
thread_id=podcast.thread_id,
|
||||
call_details={
|
||||
"podcast_id": podcast.id,
|
||||
"title": podcast.title,
|
||||
},
|
||||
):
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
except QuotaInsufficientError as exc:
|
||||
logger.info(
|
||||
"Podcast %s denied: out of premium credits "
|
||||
"(used=%d/%d remaining=%d)",
|
||||
podcast.id,
|
||||
exc.used_micros,
|
||||
exc.limit_micros,
|
||||
exc.remaining_micros,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
|
||||
podcast_transcript = graph_result.get("podcast_transcript", [])
|
||||
file_path = graph_result.get("final_podcast_file_path", "")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,13 @@ from sqlalchemy import select
|
|||
from app.agents.video_presentation.graph import graph as video_presentation_graph
|
||||
from app.agents.video_presentation.state import State as VideoPresentationState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config as app_config
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
from app.services.billable_calls import (
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -97,6 +103,32 @@ async def _generate_video_presentation(
|
|||
video_pres.status = VideoPresentationStatus.GENERATING
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
(
|
||||
owner_user_id,
|
||||
billing_tier,
|
||||
base_model,
|
||||
) = await _resolve_agent_billing_for_search_space(
|
||||
session,
|
||||
search_space_id,
|
||||
thread_id=video_pres.thread_id,
|
||||
)
|
||||
except ValueError as resolve_err:
|
||||
logger.error(
|
||||
"VideoPresentation %s: cannot resolve billing for "
|
||||
"search_space=%s: %s",
|
||||
video_pres.id,
|
||||
search_space_id,
|
||||
resolve_err,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "billing_resolution_failed",
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"video_title": video_pres.title,
|
||||
|
|
@ -110,9 +142,39 @@ async def _generate_video_presentation(
|
|||
db_session=session,
|
||||
)
|
||||
|
||||
graph_result = await video_presentation_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
try:
|
||||
async with billable_call(
|
||||
user_id=owner_user_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
|
||||
usage_type="video_presentation_generation",
|
||||
thread_id=video_pres.thread_id,
|
||||
call_details={
|
||||
"video_presentation_id": video_pres.id,
|
||||
"title": video_pres.title,
|
||||
},
|
||||
):
|
||||
graph_result = await video_presentation_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
except QuotaInsufficientError as exc:
|
||||
logger.info(
|
||||
"VideoPresentation %s denied: out of premium credits "
|
||||
"(used=%d/%d remaining=%d)",
|
||||
video_pres.id,
|
||||
exc.used_micros,
|
||||
exc.limit_micros,
|
||||
exc.remaining_micros,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
|
||||
# Serialize slides (parsed content + audio info merged)
|
||||
slides_raw = graph_result.get("slides", [])
|
||||
|
|
|
|||
|
|
@ -2236,8 +2236,10 @@ async def stream_new_chat(
|
|||
|
||||
accumulator = start_turn()
|
||||
|
||||
# Premium quota tracking state
|
||||
_premium_reserved = 0
|
||||
# Premium credit (USD micro-units) tracking state. Stores the
|
||||
# amount reserved up front so we can release it on cancellation
|
||||
# and finalize-debit the actual provider cost reported by LiteLLM.
|
||||
_premium_reserved_micros = 0
|
||||
_premium_request_id: str | None = None
|
||||
|
||||
_emit_stream_error = partial(
|
||||
|
|
@ -2331,23 +2333,28 @@ async def stream_new_chat(
|
|||
if _needs_premium_quota:
|
||||
import uuid as _uuid
|
||||
|
||||
from app.config import config as _app_config
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
|
||||
_premium_request_id = _uuid.uuid4().hex[:16]
|
||||
reserve_amount = min(
|
||||
agent_config.quota_reserve_tokens
|
||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_agent_litellm_params = agent_config.litellm_params or {}
|
||||
_agent_base_model = (
|
||||
_agent_litellm_params.get("base_model") or agent_config.model_name or ""
|
||||
)
|
||||
reserve_amount_micros = estimate_call_reserve_micros(
|
||||
base_model=_agent_base_model,
|
||||
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
reserve_tokens=reserve_amount,
|
||||
reserve_micros=reserve_amount_micros,
|
||||
)
|
||||
_premium_reserved = reserve_amount
|
||||
_premium_reserved_micros = reserve_amount_micros
|
||||
if not quota_result.allowed:
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
|
|
@ -2382,7 +2389,7 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
_premium_reserved_micros = 0
|
||||
_log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="premium_quota_exhausted",
|
||||
|
|
@ -3020,9 +3027,10 @@ async def stream_new_chat(
|
|||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] interrupted new_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3033,6 +3041,7 @@ async def stream_new_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3060,7 +3069,11 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Finalize premium quota with actual tokens.
|
||||
# Finalize premium credit debit with the actual provider cost
|
||||
# reported by LiteLLM, summed across every call in the turn.
|
||||
# Mirrors the pre-cost behaviour of "premium turn → all calls
|
||||
# count" so free sub-agent calls during a premium turn still
|
||||
# contribute to the bill (they're $0 in practice anyway).
|
||||
if _premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
|
@ -3070,11 +3083,11 @@ async def stream_new_chat(
|
|||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_premium_reserved,
|
||||
actual_micros=accumulator.total_cost_micros,
|
||||
reserved_micros=_premium_reserved_micros,
|
||||
)
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
|
|
@ -3084,9 +3097,10 @@ async def stream_new_chat(
|
|||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] normal new_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3097,6 +3111,7 @@ async def stream_new_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3190,7 +3205,7 @@ async def stream_new_chat(
|
|||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
||||
if _premium_request_id and _premium_reserved_micros > 0 and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
|
|
@ -3198,9 +3213,9 @@ async def stream_new_chat(
|
|||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_tokens=_premium_reserved,
|
||||
reserved_micros=_premium_reserved_micros,
|
||||
)
|
||||
_premium_reserved = 0
|
||||
_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s", user_id
|
||||
|
|
@ -3369,8 +3384,8 @@ async def stream_resume_chat(
|
|||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Premium quota reservation (same logic as stream_new_chat)
|
||||
_resume_premium_reserved = 0
|
||||
# Premium credit reservation (same logic as stream_new_chat).
|
||||
_resume_premium_reserved_micros = 0
|
||||
_resume_premium_request_id: str | None = None
|
||||
_resume_needs_premium = (
|
||||
agent_config is not None and user_id and agent_config.is_premium
|
||||
|
|
@ -3378,23 +3393,30 @@ async def stream_resume_chat(
|
|||
if _resume_needs_premium:
|
||||
import uuid as _uuid
|
||||
|
||||
from app.config import config as _app_config
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
|
||||
_resume_premium_request_id = _uuid.uuid4().hex[:16]
|
||||
reserve_amount = min(
|
||||
agent_config.quota_reserve_tokens
|
||||
or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_app_config.QUOTA_MAX_RESERVE_PER_CALL,
|
||||
_resume_litellm_params = agent_config.litellm_params or {}
|
||||
_resume_base_model = (
|
||||
_resume_litellm_params.get("base_model")
|
||||
or agent_config.model_name
|
||||
or ""
|
||||
)
|
||||
reserve_amount_micros = estimate_call_reserve_micros(
|
||||
base_model=_resume_base_model,
|
||||
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
reserve_tokens=reserve_amount,
|
||||
reserve_micros=reserve_amount_micros,
|
||||
)
|
||||
_resume_premium_reserved = reserve_amount
|
||||
_resume_premium_reserved_micros = reserve_amount_micros
|
||||
if not quota_result.allowed:
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
|
|
@ -3429,7 +3451,7 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
_resume_premium_reserved_micros = 0
|
||||
_log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="premium_quota_exhausted",
|
||||
|
|
@ -3746,9 +3768,10 @@ async def stream_resume_chat(
|
|||
if stream_result.is_interrupted:
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3759,6 +3782,7 @@ async def stream_resume_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3768,7 +3792,9 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Finalize premium quota for resume path
|
||||
# Finalize premium credit debit for resume path with the actual
|
||||
# provider cost reported by LiteLLM (sum of cost across all
|
||||
# calls in the turn).
|
||||
if _resume_premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
|
@ -3778,11 +3804,11 @@ async def stream_resume_chat(
|
|||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
actual_micros=accumulator.total_cost_micros,
|
||||
reserved_micros=_resume_premium_reserved_micros,
|
||||
)
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
_resume_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s (resume)",
|
||||
|
|
@ -3792,9 +3818,10 @@ async def stream_resume_chat(
|
|||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
|
||||
"[token_usage] normal resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
|
|
@ -3805,6 +3832,7 @@ async def stream_resume_chat(
|
|||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -3855,7 +3883,11 @@ async def stream_resume_chat(
|
|||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
||||
if (
|
||||
_resume_premium_request_id
|
||||
and _resume_premium_reserved_micros > 0
|
||||
and user_id
|
||||
):
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
|
|
@ -3863,9 +3895,9 @@ async def stream_resume_chat(
|
|||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
reserved_micros=_resume_premium_reserved_micros,
|
||||
)
|
||||
_resume_premium_reserved = 0
|
||||
_resume_premium_reserved_micros = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s (resume)", user_id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue