From ea3d0a64638c1bac1e3c6585fb2d53b7e6e17373 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 21 May 2026 23:02:36 +0530 Subject: [PATCH] feat(agents): emit metrics for model and tool calls --- .../agents/new_chat/middleware/otel_span.py | 103 +++++++++++++++-- .../unit/agents/new_chat/test_otel_span.py | 104 ++++++++++++++++-- 2 files changed, 190 insertions(+), 17 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py index cfe1edae4..ecaa042a9 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py +++ b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py @@ -16,13 +16,14 @@ dashboards expect. from __future__ import annotations import logging +import time from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any from langchain.agents.middleware import AgentMiddleware from langchain_core.messages import AIMessage, ToolMessage -from app.observability import otel as ot +from app.observability import metrics as ot_metrics, otel as ot if TYPE_CHECKING: # pragma: no cover — type-only from langchain.agents.middleware.types import ( @@ -62,14 +63,37 @@ class OtelSpanMiddleware(AgentMiddleware): return await handler(request) model_id, provider = _resolve_model_attrs(request) + t0 = time.perf_counter() with ot.model_call_span(model_id=model_id, provider=provider) as sp: + _annotate_model_request(sp, model_id=model_id, provider=provider) try: result = await handler(request) except Exception: + ot_metrics.record_model_call_duration( + (time.perf_counter() - t0) * 1000, + model=model_id, + provider=provider, + ) # span context manager records + re-raises raise else: - _annotate_model_response(sp, result) + input_tokens, output_tokens = _annotate_model_response( + sp, + result, + model_id=model_id, + provider=provider, + ) + ot_metrics.record_model_call_duration( + (time.perf_counter() - t0) * 1000, + model=model_id, + provider=provider, + ) + ot_metrics.record_model_token_usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + model=model_id, + provider=provider, + ) return result # ------------------------------------------------------------------ @@ -87,9 +111,24 @@ class OtelSpanMiddleware(AgentMiddleware): tool_name = _resolve_tool_name(request) input_size = _resolve_input_size(request) + t0 = time.perf_counter() with ot.tool_call_span(tool_name, input_size=input_size) as sp: - result = await handler(request) - _annotate_tool_result(sp, result) + try: + result = await handler(request) + except Exception: + ot_metrics.record_tool_call_duration( + (time.perf_counter() - t0) * 1000, + tool_name=tool_name, + ) + ot_metrics.record_tool_call_error(tool_name=tool_name) + raise + errored = _annotate_tool_result(sp, result) + ot_metrics.record_tool_call_duration( + (time.perf_counter() - t0) * 1000, + tool_name=tool_name, + ) + if errored: + ot_metrics.record_tool_call_error(tool_name=tool_name) return result @@ -154,8 +193,29 @@ def _resolve_input_size(request: Any) -> int | None: return None -def _annotate_model_response(span: Any, result: Any) -> None: +def _annotate_model_request( + span: Any, *, model_id: str | None, provider: str | None +) -> None: + try: + span.set_attribute("gen_ai.operation.name", "chat") + if model_id: + span.set_attribute("gen_ai.request.model", model_id) + if provider: + span.set_attribute("gen_ai.provider.name", provider) + except Exception: # pragma: no cover — defensive + pass + + +def _annotate_model_response( + span: Any, + result: Any, + *, + model_id: str | None = None, + provider: str | None = None, +) -> tuple[int | None, int | None]: """Best-effort: attach prompt/completion token counts when available.""" + input_tokens: int | None = None + output_tokens: int | None = None try: # ModelResponse may be a dataclass with .result containing AIMessage msg: Any @@ -165,22 +225,42 @@ def _annotate_model_response(span: Any, result: Any) -> None: inner = getattr(result, "result", None) msg = inner[-1] if isinstance(inner, list) and inner else inner if msg is None: - return + return None, None + if provider: + span.set_attribute("gen_ai.provider.name", provider) + if model_id: + span.set_attribute("gen_ai.request.model", model_id) + response_model = getattr(msg, "response_metadata", {}) or {} + if isinstance(response_model, dict): + response_model = ( + response_model.get("model_name") + or response_model.get("model") + or response_model.get("model_id") + ) + if not response_model: + response_model = model_id + if response_model: + span.set_attribute("gen_ai.response.model", str(response_model)) + span.set_attribute("gen_ai.operation.name", "chat") usage = getattr(msg, "usage_metadata", None) or {} if isinstance(usage, dict): if (n := usage.get("input_tokens")) is not None: - span.set_attribute("tokens.prompt", int(n)) + input_tokens = int(n) + span.set_attribute("gen_ai.usage.input_tokens", input_tokens) if (n := usage.get("output_tokens")) is not None: - span.set_attribute("tokens.completion", int(n)) + output_tokens = int(n) + span.set_attribute("gen_ai.usage.output_tokens", output_tokens) if (n := usage.get("total_tokens")) is not None: - span.set_attribute("tokens.total", int(n)) + span.set_attribute("gen_ai.usage.total_tokens", int(n)) tool_calls = getattr(msg, "tool_calls", None) or [] span.set_attribute("model.tool_calls", len(tool_calls)) except Exception: # pragma: no cover — defensive pass + return input_tokens, output_tokens -def _annotate_tool_result(span: Any, result: Any) -> None: +def _annotate_tool_result(span: Any, result: Any) -> bool: + errored = False try: if isinstance(result, ToolMessage): content = ( @@ -192,11 +272,14 @@ def _annotate_tool_result(span: Any, result: Any) -> None: status = getattr(result, "status", None) if isinstance(status, str): span.set_attribute("tool.status", status) + errored = status.lower() == "error" kwargs = getattr(result, "additional_kwargs", None) or {} if isinstance(kwargs, dict) and kwargs.get("error"): span.set_attribute("tool.error", True) + errored = True except Exception: # pragma: no cover — defensive pass + return errored __all__ = ["OtelSpanMiddleware"] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py index 55434c04d..dc59c6dac 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py @@ -23,6 +23,7 @@ pytestmark = pytest.mark.unit @pytest.fixture(autouse=True) def _disable_otel(monkeypatch: pytest.MonkeyPatch): monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") from app.observability import otel as ot @@ -99,16 +100,17 @@ class TestAnnotateModelResponse: "total_tokens": 150, }, ) - _annotate_model_response(sp, msg) - sp.set_attribute.assert_any_call("tokens.prompt", 100) - sp.set_attribute.assert_any_call("tokens.completion", 50) - sp.set_attribute.assert_any_call("tokens.total", 150) + assert _annotate_model_response(sp, msg) == (100, 50) + sp.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 100) + sp.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 50) + sp.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + sp.set_attribute.assert_any_call("gen_ai.operation.name", "chat") def test_handles_response_with_no_metadata(self) -> None: sp = MagicMock() msg = AIMessage(content="hello") # Should not raise even when usage_metadata is missing - _annotate_model_response(sp, msg) + assert _annotate_model_response(sp, msg) == (None, None) class TestAnnotateToolResult: @@ -119,7 +121,7 @@ class TestAnnotateToolResult: tool_call_id="abc", status="success", ) - _annotate_tool_result(sp, result) + assert _annotate_tool_result(sp, result) is False sp.set_attribute.assert_any_call("tool.output.size", len("result text")) sp.set_attribute.assert_any_call("tool.status", "success") @@ -130,7 +132,7 @@ class TestAnnotateToolResult: tool_call_id="abc", additional_kwargs={"error": {"code": "x"}}, ) - _annotate_tool_result(sp, result) + assert _annotate_tool_result(sp, result) is True sp.set_attribute.assert_any_call("tool.error", True) @@ -193,3 +195,91 @@ class TestMiddlewareIntegration: assert result.content == "enabled" finally: ot.reload_for_tests() + + async def test_enabled_model_call_records_metrics( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + from app.observability import otel as ot + + duration_calls: list[dict[str, Any]] = [] + token_calls: list[dict[str, Any]] = [] + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_call_duration", + lambda duration_ms, **attrs: duration_calls.append( + {"duration_ms": duration_ms, **attrs} + ), + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_token_usage", + lambda **attrs: token_calls.append(attrs), + ) + + ot.reload_for_tests() + try: + mw = OtelSpanMiddleware() + + async def handler(req): + return AIMessage( + content="enabled", + usage_metadata={ + "input_tokens": 3, + "output_tokens": 5, + "total_tokens": 8, + }, + ) + + request = MagicMock() + request.model = MagicMock() + request.model.model_name = "gpt-4o" + request.model.provider = "openai" + await mw.awrap_model_call(request, handler) + + assert duration_calls + assert token_calls == [ + { + "input_tokens": 3, + "output_tokens": 5, + "model": "gpt-4o", + "provider": "openai", + } + ] + finally: + ot.reload_for_tests() + + async def test_enabled_tool_call_records_error_metric( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + from app.observability import otel as ot + + errors: list[str] = [] + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_error", + lambda *, tool_name: errors.append(tool_name), + ) + monkeypatch.setattr( + "app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_duration", + lambda *args, **kwargs: None, + ) + + ot.reload_for_tests() + try: + mw = OtelSpanMiddleware() + + async def handler(req): + return ToolMessage( + content="failed", + tool_call_id="abc", + status="error", + ) + + request = MagicMock() + request.tool = MagicMock() + request.tool.name = "web_search" + await mw.awrap_tool_call(request, handler) + assert errors == ["web_search"] + finally: + ot.reload_for_tests()