feat(agents): emit metrics for model and tool calls

This commit is contained in:
Anish Sarkar 2026-05-21 23:02:36 +05:30
parent 6095b48b5f
commit ea3d0a6463
2 changed files with 190 additions and 17 deletions

View file

@ -16,13 +16,14 @@ dashboards expect.
from __future__ import annotations
import logging
import time
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, ToolMessage
from app.observability import otel as ot
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING: # pragma: no cover — type-only
from langchain.agents.middleware.types import (
@ -62,14 +63,37 @@ class OtelSpanMiddleware(AgentMiddleware):
return await handler(request)
model_id, provider = _resolve_model_attrs(request)
t0 = time.perf_counter()
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
_annotate_model_request(sp, model_id=model_id, provider=provider)
try:
result = await handler(request)
except Exception:
ot_metrics.record_model_call_duration(
(time.perf_counter() - t0) * 1000,
model=model_id,
provider=provider,
)
# span context manager records + re-raises
raise
else:
_annotate_model_response(sp, result)
input_tokens, output_tokens = _annotate_model_response(
sp,
result,
model_id=model_id,
provider=provider,
)
ot_metrics.record_model_call_duration(
(time.perf_counter() - t0) * 1000,
model=model_id,
provider=provider,
)
ot_metrics.record_model_token_usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
model=model_id,
provider=provider,
)
return result
# ------------------------------------------------------------------
@ -87,9 +111,24 @@ class OtelSpanMiddleware(AgentMiddleware):
tool_name = _resolve_tool_name(request)
input_size = _resolve_input_size(request)
t0 = time.perf_counter()
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
result = await handler(request)
_annotate_tool_result(sp, result)
try:
result = await handler(request)
except Exception:
ot_metrics.record_tool_call_duration(
(time.perf_counter() - t0) * 1000,
tool_name=tool_name,
)
ot_metrics.record_tool_call_error(tool_name=tool_name)
raise
errored = _annotate_tool_result(sp, result)
ot_metrics.record_tool_call_duration(
(time.perf_counter() - t0) * 1000,
tool_name=tool_name,
)
if errored:
ot_metrics.record_tool_call_error(tool_name=tool_name)
return result
@ -154,8 +193,29 @@ def _resolve_input_size(request: Any) -> int | None:
return None
def _annotate_model_response(span: Any, result: Any) -> None:
def _annotate_model_request(
span: Any, *, model_id: str | None, provider: str | None
) -> None:
try:
span.set_attribute("gen_ai.operation.name", "chat")
if model_id:
span.set_attribute("gen_ai.request.model", model_id)
if provider:
span.set_attribute("gen_ai.provider.name", provider)
except Exception: # pragma: no cover — defensive
pass
def _annotate_model_response(
span: Any,
result: Any,
*,
model_id: str | None = None,
provider: str | None = None,
) -> tuple[int | None, int | None]:
"""Best-effort: attach prompt/completion token counts when available."""
input_tokens: int | None = None
output_tokens: int | None = None
try:
# ModelResponse may be a dataclass with .result containing AIMessage
msg: Any
@ -165,22 +225,42 @@ def _annotate_model_response(span: Any, result: Any) -> None:
inner = getattr(result, "result", None)
msg = inner[-1] if isinstance(inner, list) and inner else inner
if msg is None:
return
return None, None
if provider:
span.set_attribute("gen_ai.provider.name", provider)
if model_id:
span.set_attribute("gen_ai.request.model", model_id)
response_model = getattr(msg, "response_metadata", {}) or {}
if isinstance(response_model, dict):
response_model = (
response_model.get("model_name")
or response_model.get("model")
or response_model.get("model_id")
)
if not response_model:
response_model = model_id
if response_model:
span.set_attribute("gen_ai.response.model", str(response_model))
span.set_attribute("gen_ai.operation.name", "chat")
usage = getattr(msg, "usage_metadata", None) or {}
if isinstance(usage, dict):
if (n := usage.get("input_tokens")) is not None:
span.set_attribute("tokens.prompt", int(n))
input_tokens = int(n)
span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
if (n := usage.get("output_tokens")) is not None:
span.set_attribute("tokens.completion", int(n))
output_tokens = int(n)
span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
if (n := usage.get("total_tokens")) is not None:
span.set_attribute("tokens.total", int(n))
span.set_attribute("gen_ai.usage.total_tokens", int(n))
tool_calls = getattr(msg, "tool_calls", None) or []
span.set_attribute("model.tool_calls", len(tool_calls))
except Exception: # pragma: no cover — defensive
pass
return input_tokens, output_tokens
def _annotate_tool_result(span: Any, result: Any) -> None:
def _annotate_tool_result(span: Any, result: Any) -> bool:
errored = False
try:
if isinstance(result, ToolMessage):
content = (
@ -192,11 +272,14 @@ def _annotate_tool_result(span: Any, result: Any) -> None:
status = getattr(result, "status", None)
if isinstance(status, str):
span.set_attribute("tool.status", status)
errored = status.lower() == "error"
kwargs = getattr(result, "additional_kwargs", None) or {}
if isinstance(kwargs, dict) and kwargs.get("error"):
span.set_attribute("tool.error", True)
errored = True
except Exception: # pragma: no cover — defensive
pass
return errored
__all__ = ["OtelSpanMiddleware"]