mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
feat(agents): emit metrics for model and tool calls
This commit is contained in:
parent
6095b48b5f
commit
ea3d0a6463
2 changed files with 190 additions and 17 deletions
|
|
@ -16,13 +16,14 @@ dashboards expect.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.observability import otel as ot
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover — type-only
|
||||
from langchain.agents.middleware.types import (
|
||||
|
|
@ -62,14 +63,37 @@ class OtelSpanMiddleware(AgentMiddleware):
|
|||
return await handler(request)
|
||||
|
||||
model_id, provider = _resolve_model_attrs(request)
|
||||
t0 = time.perf_counter()
|
||||
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
|
||||
_annotate_model_request(sp, model_id=model_id, provider=provider)
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception:
|
||||
ot_metrics.record_model_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
# span context manager records + re-raises
|
||||
raise
|
||||
else:
|
||||
_annotate_model_response(sp, result)
|
||||
input_tokens, output_tokens = _annotate_model_response(
|
||||
sp,
|
||||
result,
|
||||
model_id=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
ot_metrics.record_model_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
ot_metrics.record_model_token_usage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -87,9 +111,24 @@ class OtelSpanMiddleware(AgentMiddleware):
|
|||
tool_name = _resolve_tool_name(request)
|
||||
input_size = _resolve_input_size(request)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
|
||||
result = await handler(request)
|
||||
_annotate_tool_result(sp, result)
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception:
|
||||
ot_metrics.record_tool_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
ot_metrics.record_tool_call_error(tool_name=tool_name)
|
||||
raise
|
||||
errored = _annotate_tool_result(sp, result)
|
||||
ot_metrics.record_tool_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
if errored:
|
||||
ot_metrics.record_tool_call_error(tool_name=tool_name)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -154,8 +193,29 @@ def _resolve_input_size(request: Any) -> int | None:
|
|||
return None
|
||||
|
||||
|
||||
def _annotate_model_response(span: Any, result: Any) -> None:
|
||||
def _annotate_model_request(
|
||||
span: Any, *, model_id: str | None, provider: str | None
|
||||
) -> None:
|
||||
try:
|
||||
span.set_attribute("gen_ai.operation.name", "chat")
|
||||
if model_id:
|
||||
span.set_attribute("gen_ai.request.model", model_id)
|
||||
if provider:
|
||||
span.set_attribute("gen_ai.provider.name", provider)
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
|
||||
|
||||
def _annotate_model_response(
|
||||
span: Any,
|
||||
result: Any,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
) -> tuple[int | None, int | None]:
|
||||
"""Best-effort: attach prompt/completion token counts when available."""
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
try:
|
||||
# ModelResponse may be a dataclass with .result containing AIMessage
|
||||
msg: Any
|
||||
|
|
@ -165,22 +225,42 @@ def _annotate_model_response(span: Any, result: Any) -> None:
|
|||
inner = getattr(result, "result", None)
|
||||
msg = inner[-1] if isinstance(inner, list) and inner else inner
|
||||
if msg is None:
|
||||
return
|
||||
return None, None
|
||||
if provider:
|
||||
span.set_attribute("gen_ai.provider.name", provider)
|
||||
if model_id:
|
||||
span.set_attribute("gen_ai.request.model", model_id)
|
||||
response_model = getattr(msg, "response_metadata", {}) or {}
|
||||
if isinstance(response_model, dict):
|
||||
response_model = (
|
||||
response_model.get("model_name")
|
||||
or response_model.get("model")
|
||||
or response_model.get("model_id")
|
||||
)
|
||||
if not response_model:
|
||||
response_model = model_id
|
||||
if response_model:
|
||||
span.set_attribute("gen_ai.response.model", str(response_model))
|
||||
span.set_attribute("gen_ai.operation.name", "chat")
|
||||
usage = getattr(msg, "usage_metadata", None) or {}
|
||||
if isinstance(usage, dict):
|
||||
if (n := usage.get("input_tokens")) is not None:
|
||||
span.set_attribute("tokens.prompt", int(n))
|
||||
input_tokens = int(n)
|
||||
span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
|
||||
if (n := usage.get("output_tokens")) is not None:
|
||||
span.set_attribute("tokens.completion", int(n))
|
||||
output_tokens = int(n)
|
||||
span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
|
||||
if (n := usage.get("total_tokens")) is not None:
|
||||
span.set_attribute("tokens.total", int(n))
|
||||
span.set_attribute("gen_ai.usage.total_tokens", int(n))
|
||||
tool_calls = getattr(msg, "tool_calls", None) or []
|
||||
span.set_attribute("model.tool_calls", len(tool_calls))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return input_tokens, output_tokens
|
||||
|
||||
|
||||
def _annotate_tool_result(span: Any, result: Any) -> None:
|
||||
def _annotate_tool_result(span: Any, result: Any) -> bool:
|
||||
errored = False
|
||||
try:
|
||||
if isinstance(result, ToolMessage):
|
||||
content = (
|
||||
|
|
@ -192,11 +272,14 @@ def _annotate_tool_result(span: Any, result: Any) -> None:
|
|||
status = getattr(result, "status", None)
|
||||
if isinstance(status, str):
|
||||
span.set_attribute("tool.status", status)
|
||||
errored = status.lower() == "error"
|
||||
kwargs = getattr(result, "additional_kwargs", None) or {}
|
||||
if isinstance(kwargs, dict) and kwargs.get("error"):
|
||||
span.set_attribute("tool.error", True)
|
||||
errored = True
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return errored
|
||||
|
||||
|
||||
__all__ = ["OtelSpanMiddleware"]
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ pytestmark = pytest.mark.unit
|
|||
@pytest.fixture(autouse=True)
|
||||
def _disable_otel(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False)
|
||||
monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False)
|
||||
monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
|
||||
from app.observability import otel as ot
|
||||
|
||||
|
|
@ -99,16 +100,17 @@ class TestAnnotateModelResponse:
|
|||
"total_tokens": 150,
|
||||
},
|
||||
)
|
||||
_annotate_model_response(sp, msg)
|
||||
sp.set_attribute.assert_any_call("tokens.prompt", 100)
|
||||
sp.set_attribute.assert_any_call("tokens.completion", 50)
|
||||
sp.set_attribute.assert_any_call("tokens.total", 150)
|
||||
assert _annotate_model_response(sp, msg) == (100, 50)
|
||||
sp.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 100)
|
||||
sp.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 50)
|
||||
sp.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150)
|
||||
sp.set_attribute.assert_any_call("gen_ai.operation.name", "chat")
|
||||
|
||||
def test_handles_response_with_no_metadata(self) -> None:
|
||||
sp = MagicMock()
|
||||
msg = AIMessage(content="hello")
|
||||
# Should not raise even when usage_metadata is missing
|
||||
_annotate_model_response(sp, msg)
|
||||
assert _annotate_model_response(sp, msg) == (None, None)
|
||||
|
||||
|
||||
class TestAnnotateToolResult:
|
||||
|
|
@ -119,7 +121,7 @@ class TestAnnotateToolResult:
|
|||
tool_call_id="abc",
|
||||
status="success",
|
||||
)
|
||||
_annotate_tool_result(sp, result)
|
||||
assert _annotate_tool_result(sp, result) is False
|
||||
sp.set_attribute.assert_any_call("tool.output.size", len("result text"))
|
||||
sp.set_attribute.assert_any_call("tool.status", "success")
|
||||
|
||||
|
|
@ -130,7 +132,7 @@ class TestAnnotateToolResult:
|
|||
tool_call_id="abc",
|
||||
additional_kwargs={"error": {"code": "x"}},
|
||||
)
|
||||
_annotate_tool_result(sp, result)
|
||||
assert _annotate_tool_result(sp, result) is True
|
||||
sp.set_attribute.assert_any_call("tool.error", True)
|
||||
|
||||
|
||||
|
|
@ -193,3 +195,91 @@ class TestMiddlewareIntegration:
|
|||
assert result.content == "enabled"
|
||||
finally:
|
||||
ot.reload_for_tests()
|
||||
|
||||
async def test_enabled_model_call_records_metrics(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
|
||||
from app.observability import otel as ot
|
||||
|
||||
duration_calls: list[dict[str, Any]] = []
|
||||
token_calls: list[dict[str, Any]] = []
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_call_duration",
|
||||
lambda duration_ms, **attrs: duration_calls.append(
|
||||
{"duration_ms": duration_ms, **attrs}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_model_token_usage",
|
||||
lambda **attrs: token_calls.append(attrs),
|
||||
)
|
||||
|
||||
ot.reload_for_tests()
|
||||
try:
|
||||
mw = OtelSpanMiddleware()
|
||||
|
||||
async def handler(req):
|
||||
return AIMessage(
|
||||
content="enabled",
|
||||
usage_metadata={
|
||||
"input_tokens": 3,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 8,
|
||||
},
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.model = MagicMock()
|
||||
request.model.model_name = "gpt-4o"
|
||||
request.model.provider = "openai"
|
||||
await mw.awrap_model_call(request, handler)
|
||||
|
||||
assert duration_calls
|
||||
assert token_calls == [
|
||||
{
|
||||
"input_tokens": 3,
|
||||
"output_tokens": 5,
|
||||
"model": "gpt-4o",
|
||||
"provider": "openai",
|
||||
}
|
||||
]
|
||||
finally:
|
||||
ot.reload_for_tests()
|
||||
|
||||
async def test_enabled_tool_call_records_error_metric(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
|
||||
from app.observability import otel as ot
|
||||
|
||||
errors: list[str] = []
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_error",
|
||||
lambda *, tool_name: errors.append(tool_name),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.otel_span.ot_metrics.record_tool_call_duration",
|
||||
lambda *args, **kwargs: None,
|
||||
)
|
||||
|
||||
ot.reload_for_tests()
|
||||
try:
|
||||
mw = OtelSpanMiddleware()
|
||||
|
||||
async def handler(req):
|
||||
return ToolMessage(
|
||||
content="failed",
|
||||
tool_call_id="abc",
|
||||
status="error",
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.tool = MagicMock()
|
||||
request.tool.name = "web_search"
|
||||
await mw.awrap_tool_call(request, handler)
|
||||
assert errors == ["web_search"]
|
||||
finally:
|
||||
ot.reload_for_tests()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue