feat(agents): emit metrics for model and tool calls

This commit is contained in:
Anish Sarkar 2026-05-21 23:02:36 +05:30
parent 6095b48b5f
commit ea3d0a6463
2 changed files with 190 additions and 17 deletions

View file

@ -16,13 +16,14 @@ dashboards expect.
from __future__ import annotations
import logging
import time
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, ToolMessage
from app.observability import otel as ot
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING: # pragma: no cover — type-only
from langchain.agents.middleware.types import (
@ -62,14 +63,37 @@ class OtelSpanMiddleware(AgentMiddleware):
return await handler(request)
model_id, provider = _resolve_model_attrs(request)
t0 = time.perf_counter()
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
_annotate_model_request(sp, model_id=model_id, provider=provider)
try:
result = await handler(request)
except Exception:
ot_metrics.record_model_call_duration(
(time.perf_counter() - t0) * 1000,
model=model_id,
provider=provider,
)
# span context manager records + re-raises
raise
else:
_annotate_model_response(sp, result)
input_tokens, output_tokens = _annotate_model_response(
sp,
result,
model_id=model_id,
provider=provider,
)
ot_metrics.record_model_call_duration(
(time.perf_counter() - t0) * 1000,
model=model_id,
provider=provider,
)
ot_metrics.record_model_token_usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
model=model_id,
provider=provider,
)
return result
# ------------------------------------------------------------------
@ -87,9 +111,24 @@ class OtelSpanMiddleware(AgentMiddleware):
tool_name = _resolve_tool_name(request)
input_size = _resolve_input_size(request)
t0 = time.perf_counter()
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
result = await handler(request)
_annotate_tool_result(sp, result)
try:
result = await handler(request)
except Exception:
ot_metrics.record_tool_call_duration(
(time.perf_counter() - t0) * 1000,
tool_name=tool_name,
)
ot_metrics.record_tool_call_error(tool_name=tool_name)
raise
errored = _annotate_tool_result(sp, result)
ot_metrics.record_tool_call_duration(
(time.perf_counter() - t0) * 1000,
tool_name=tool_name,
)
if errored:
ot_metrics.record_tool_call_error(tool_name=tool_name)
return result
@ -154,8 +193,29 @@ def _resolve_input_size(request: Any) -> int | None:
return None
def _annotate_model_response(span: Any, result: Any) -> None:
def _annotate_model_request(
span: Any, *, model_id: str | None, provider: str | None
) -> None:
try:
span.set_attribute("gen_ai.operation.name", "chat")
if model_id:
span.set_attribute("gen_ai.request.model", model_id)
if provider:
span.set_attribute("gen_ai.provider.name", provider)
except Exception: # pragma: no cover — defensive
pass
def _annotate_model_response(
span: Any,
result: Any,
*,
model_id: str | None = None,
provider: str | None = None,
) -> tuple[int | None, int | None]:
"""Best-effort: attach prompt/completion token counts when available."""
input_tokens: int | None = None
output_tokens: int | None = None
try:
# ModelResponse may be a dataclass with .result containing AIMessage
msg: Any
@ -165,22 +225,42 @@ def _annotate_model_response(span: Any, result: Any) -> None:
inner = getattr(result, "result", None)
msg = inner[-1] if isinstance(inner, list) and inner else inner
if msg is None:
return
return None, None
if provider:
span.set_attribute("gen_ai.provider.name", provider)
if model_id:
span.set_attribute("gen_ai.request.model", model_id)
response_model = getattr(msg, "response_metadata", {}) or {}
if isinstance(response_model, dict):
response_model = (
response_model.get("model_name")
or response_model.get("model")
or response_model.get("model_id")
)
if not response_model:
response_model = model_id
if response_model:
span.set_attribute("gen_ai.response.model", str(response_model))
span.set_attribute("gen_ai.operation.name", "chat")
usage = getattr(msg, "usage_metadata", None) or {}
if isinstance(usage, dict):
if (n := usage.get("input_tokens")) is not None:
span.set_attribute("tokens.prompt", int(n))
input_tokens = int(n)
span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
if (n := usage.get("output_tokens")) is not None:
span.set_attribute("tokens.completion", int(n))
output_tokens = int(n)
span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
if (n := usage.get("total_tokens")) is not None:
span.set_attribute("tokens.total", int(n))
span.set_attribute("gen_ai.usage.total_tokens", int(n))
tool_calls = getattr(msg, "tool_calls", None) or []
span.set_attribute("model.tool_calls", len(tool_calls))
except Exception: # pragma: no cover — defensive
pass
return input_tokens, output_tokens
def _annotate_tool_result(span: Any, result: Any) -> None:
def _annotate_tool_result(span: Any, result: Any) -> bool:
errored = False
try:
if isinstance(result, ToolMessage):
content = (
@ -192,11 +272,14 @@ def _annotate_tool_result(span: Any, result: Any) -> None:
status = getattr(result, "status", None)
if isinstance(status, str):
span.set_attribute("tool.status", status)
errored = status.lower() == "error"
kwargs = getattr(result, "additional_kwargs", None) or {}
if isinstance(kwargs, dict) and kwargs.get("error"):
span.set_attribute("tool.error", True)
errored = True
except Exception: # pragma: no cover — defensive
pass
return errored
__all__ = ["OtelSpanMiddleware"]

View file

@ -23,6 +23,7 @@ pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _disable_otel(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False)
monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
from app.observability import otel as ot
@ -99,16 +100,17 @@ class TestAnnotateModelResponse:
"total_tokens": 150,
},
)
_annotate_model_response(sp, msg)
sp.set_attribute.assert_any_call("tokens.prompt", 100)
sp.set_attribute.assert_any_call("tokens.completion", 50)
sp.set_attribute.assert_any_call("tokens.total", 150)
assert _annotate_model_response(sp, msg) == (100, 50)
sp.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 100)
sp.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 50)
sp.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150)
sp.set_attribute.assert_any_call("gen_ai.operation.name", "chat")
def test_handles_response_with_no_metadata(self) -> None:
sp = MagicMock()
msg = AIMessage(content="hello")
# Should not raise even when usage_metadata is missing
_annotate_model_response(sp, msg)
assert _annotate_model_response(sp, msg) == (None, None)
class TestAnnotateToolResult:
@ -119,7 +121,7 @@ class TestAnnotateToolResult:
tool_call_id="abc",
status="success",
)
_annotate_tool_result(sp, result)
assert _annotate_tool_result(sp, result) is False
sp.set_attribute.assert_any_call("tool.output.size", len("result text"))
sp.set_attribute.assert_any_call("tool.status", "success")
@ -130,7 +132,7 @@ class TestAnnotateToolResult:
tool_call_id="abc",
additional_kwargs={"error": {"code": "x"}},
)
_annotate_tool_result(sp, result)
assert _annotate_tool_result(sp, result) is True
sp.set_attribute.assert_any_call("tool.error", True)
@ -193,3 +195,91 @@ class TestMiddlewareIntegration:
assert result.content == "enabled"
finally:
ot.reload_for_tests()
async def test_enabled_model_call_records_metrics(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
from app.observability import otel as ot
duration_calls: list[dict[str, Any]] = []
token_calls: list[dict[str, Any]] = []
monkeypatch.setattr(
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_call_duration",
lambda duration_ms, **attrs: duration_calls.append(
{"duration_ms": duration_ms, **attrs}
),
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_token_usage",
lambda **attrs: token_calls.append(attrs),
)
ot.reload_for_tests()
try:
mw = OtelSpanMiddleware()
async def handler(req):
return AIMessage(
content="enabled",
usage_metadata={
"input_tokens": 3,
"output_tokens": 5,
"total_tokens": 8,
},
)
request = MagicMock()
request.model = MagicMock()
request.model.model_name = "gpt-4o"
request.model.provider = "openai"
await mw.awrap_model_call(request, handler)
assert duration_calls
assert token_calls == [
{
"input_tokens": 3,
"output_tokens": 5,
"model": "gpt-4o",
"provider": "openai",
}
]
finally:
ot.reload_for_tests()
async def test_enabled_tool_call_records_error_metric(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
from app.observability import otel as ot
errors: list[str] = []
monkeypatch.setattr(
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_error",
lambda *, tool_name: errors.append(tool_name),
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_duration",
lambda *args, **kwargs: None,
)
ot.reload_for_tests()
try:
mw = OtelSpanMiddleware()
async def handler(req):
return ToolMessage(
content="failed",
tool_call_id="abc",
status="error",
)
request = MagicMock()
request.tool = MagicMock()
request.tool.name = "web_search"
await mw.awrap_tool_call(request, handler)
assert errors == ["web_search"]
finally:
ot.reload_for_tests()