Merge pull request #1491 from AnishSarkar22/feat/unified-model-connections

feat: Fix model attribution for prefix-stripped token usage callbacks
This commit is contained in:
Rohan Verma 2026-06-14 17:50:48 -07:00 committed by GitHub
commit 69bdcf5946
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 190 additions and 32 deletions

View file

@ -32,6 +32,23 @@ from app.db import TokenUsage
logger = logging.getLogger(__name__)
def _bare_model_name(model: str) -> str:
    """Return a model identifier with any provider routing prefix stripped.

    LiteLLM's ``get_llm_provider`` consumes the provider prefix we add in
    ``to_litellm`` (e.g. ``azure/gpt-5.2-chat`` -> ``gpt-5.2-chat`` because
    ``azure`` is in ``litellm.provider_list``). The token-tracking success
    callback therefore reports ``kwargs["model"]`` *without* that prefix,
    while model metadata is registered under the *prefixed* string. Normalising
    both sides to the last path segment lets the two reconcile so the per-model
    breakdown carries provider/display_name and the UI attributes the turn to
    the correct connection instead of falling back to a bare-name collision.

    A falsy ``model`` (empty string or ``None``) is returned unchanged.
    """
    if not model:
        # Nothing to strip; hand the falsy value back exactly as received.
        return model
    # Keep only the text after the final "/"; a name containing no "/" comes
    # back unchanged because rsplit then yields a single-element list.
    return model.rsplit("/", 1)[-1]
@dataclass
class TokenCallRecord:
model: str
@ -52,6 +69,12 @@ class TurnTokenAccumulator:
calls: list[TokenCallRecord] = field(default_factory=list)
model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict)
# Secondary index keyed by the bare model name (provider prefix stripped) so
# the LiteLLM callback — which never sees our routing prefix — can still
# reconcile its ``kwargs["model"]`` back to the registered metadata.
model_metadata_by_bare: dict[str, dict[str, str | None]] = field(
default_factory=dict
)
def register_model_metadata(
self,
@ -63,12 +86,28 @@ class TurnTokenAccumulator:
provider: str | None,
) -> None:
"""Attach resolved model metadata for later LiteLLM callback attribution."""
self.model_metadata[model] = {
metadata = {
"model_ref": model_ref,
"model_id": model_id,
"display_name": display_name,
"provider": provider,
}
self.model_metadata[model] = metadata
# Index every reconcilable alias: the prefixed string's bare form and
# the resolved ``model_id`` (which for some providers is itself the bare
# deployment LiteLLM reports). Exact lookups always take precedence.
self.model_metadata_by_bare[_bare_model_name(model)] = metadata
if model_id:
self.model_metadata_by_bare.setdefault(_bare_model_name(model_id), metadata)
def _lookup_metadata(self, model: str) -> dict[str, str | None]:
    """Resolve registered metadata for a callback-reported model name.

    An exact match in ``self.model_metadata`` always takes precedence.
    Otherwise the name is reduced with :func:`_bare_model_name` and looked
    up in ``self.model_metadata_by_bare``, which tolerates the
    provider-prefix stripping LiteLLM applies before the success callback
    fires.

    Returns an empty dict when neither index holds an entry for ``model``.
    """
    exact = self.model_metadata.get(model)
    if exact is not None:
        return exact
    # Fall back to the secondary index keyed by the prefix-stripped name.
    return self.model_metadata_by_bare.get(_bare_model_name(model), {})
def add(
self,
@ -79,7 +118,7 @@ class TurnTokenAccumulator:
cost_micros: int = 0,
call_kind: str = "chat",
) -> None:
metadata = self.model_metadata.get(model, {})
metadata = self._lookup_metadata(model)
self.calls.append(
TokenCallRecord(
model=model,