mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-20 21:18:13 +02:00
Merge pull request #1491 from AnishSarkar22/feat/unified-model-connections
feat: Fix model attribution for prefix-stripped token usage callbacks
This commit is contained in:
commit
69bdcf5946
7 changed files with 190 additions and 32 deletions
|
|
@ -112,6 +112,77 @@ def test_per_message_summary_groups_cost_by_model():
|
|||
assert summary["gpt-4o-mini"]["cost_micros"] == 200
|
||||
|
||||
|
||||
def test_add_reconciles_metadata_when_litellm_strips_provider_prefix():
    """Metadata registered under a prefixed model must attach to bare-name usage.

    Regression: LiteLLM's ``get_llm_provider`` strips the provider prefix that
    ``to_litellm`` adds (``azure/gpt-5.2-chat`` becomes ``gpt-5.2-chat``
    because ``azure`` is in ``litellm.provider_list``), so the success
    callback reports the bare model name. The metadata registered under the
    *prefixed* string must still be attached to that call so the per-model
    breakdown carries provider/display_name. Otherwise the UI falls back to a
    bare-name collision and mis-attributes an Azure turn to an OpenRouter
    model (e.g. shows "OpenAI: GPT-5.2 Chat").
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    accumulator = TurnTokenAccumulator()

    # Register under the provider-prefixed name.
    registration = {
        "model": "azure/gpt-5.2-chat",
        "model_ref": "global:-1",
        "model_id": "gpt-5.2-chat",
        "display_name": "Azure GPT 5.2",
        "provider": "azure",
    }
    accumulator.register_model_metadata(**registration)

    # The LiteLLM callback fires with the prefix-stripped model name.
    usage = {
        "model": "gpt-5.2-chat",
        "prompt_tokens": 100,
        "completion_tokens": 50,
        "total_tokens": 150,
        "cost_micros": 4_000,
    }
    accumulator.add(**usage)

    breakdown = accumulator.per_message_summary()["gpt-5.2-chat"]

    # Checked one field at a time, in this order, so a failure names the
    # first metadata field that did not carry over.
    expected_metadata = (
        ("provider", "azure"),
        ("display_name", "Azure GPT 5.2"),
        ("model_id", "gpt-5.2-chat"),
        ("model_ref", "global:-1"),
    )
    for field, value in expected_metadata:
        assert breakdown[field] == value
|
||||
|
||||
|
||||
def test_add_prefers_exact_metadata_over_bare_alias():
    """An exact key match beats a shared bare-name alias.

    When the callback model matches a registered key exactly, that exact
    metadata wins even if another registered model shares the same bare name,
    so a turn that legitimately used two same-named deployments stays
    correctly attributed.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    accumulator = TurnTokenAccumulator()

    # Two deployments whose names collide once the provider prefix is removed.
    # Registration order (Azure first, then OpenAI) is kept as in the original.
    deployments = (
        {
            "model": "azure/gpt-5.2-chat",
            "model_ref": "global:-1",
            "model_id": "gpt-5.2-chat",
            "display_name": "Azure GPT 5.2",
            "provider": "azure",
        },
        {
            "model": "openai/gpt-5.2-chat",
            "model_ref": "db:7",
            "model_id": "gpt-5.2-chat",
            "display_name": "OpenAI GPT 5.2",
            "provider": "openai",
        },
    )
    for deployment in deployments:
        accumulator.register_model_metadata(**deployment)

    # Usage is reported under the full, prefixed OpenAI key.
    accumulator.add(
        model="openai/gpt-5.2-chat",
        prompt_tokens=10,
        completion_tokens=5,
        total_tokens=15,
        cost_micros=100,
    )

    summary = accumulator.per_message_summary()
    openai_entry = summary["openai/gpt-5.2-chat"]
    assert openai_entry["provider"] == "openai"
    assert openai_entry["display_name"] == "OpenAI GPT 5.2"
|
||||
|
||||
|
||||
def test_serialized_calls_includes_cost_micros():
|
||||
"""``serialized_calls`` is what flows into the SSE ``call_details``
|
||||
payload; cost_micros must be present on each entry so the FE message-info
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue