feat(observability): add SurfSense metric helpers

This commit is contained in:
Anish Sarkar 2026-05-21 23:02:20 +05:30
parent eb2e2b253b
commit 6095b48b5f
3 changed files with 581 additions and 2 deletions

View file

@ -0,0 +1,416 @@
"""Custom OpenTelemetry metrics for SurfSense.
This module owns all SurfSense-specific metric instruments. Callers use the
small helper functions below instead of constructing instruments directly so
attribute names and cardinality stay consistent across the backend.
"""
from __future__ import annotations
import contextlib
import gc
import logging
from functools import lru_cache
from importlib import metadata
from typing import Any
from app.observability import otel
logger = logging.getLogger(__name__)
_INSTRUMENTATION_NAME = "surfsense.platform"
_OBSERVABLES_REGISTERED = False
def _package_version() -> str:
with contextlib.suppress(metadata.PackageNotFoundError):
return metadata.version("surf-new-backend")
return "unknown"
def _is_enabled() -> bool:
return otel.is_enabled()
def _clean_attrs(attrs: dict[str, Any]) -> dict[str, str | int | float | bool]:
"""Drop empty values and coerce low-cardinality attrs to OTel-safe scalars."""
cleaned: dict[str, str | int | float | bool] = {}
for key, value in attrs.items():
if value is None:
continue
if isinstance(value, bool | int | float):
cleaned[key] = value
continue
text = str(value)
if text:
cleaned[key] = text
return cleaned
def _record(callable_obj: Any, value: int | float, attrs: dict[str, Any]) -> None:
if not _is_enabled():
return
with contextlib.suppress(Exception):
callable_obj.record(value, _clean_attrs(attrs))
def _add(callable_obj: Any, value: int, attrs: dict[str, Any]) -> None:
if not _is_enabled():
return
with contextlib.suppress(Exception):
callable_obj.add(value, _clean_attrs(attrs))
@lru_cache(maxsize=1)
def _get_meter():
from opentelemetry import metrics
return metrics.get_meter(_INSTRUMENTATION_NAME, _package_version())
@lru_cache(maxsize=1)
def _model_call_duration():
return _get_meter().create_histogram(
"surfsense.model.call.duration",
unit="ms",
description="Duration of SurfSense LLM model calls.",
)
@lru_cache(maxsize=1)
def _model_token_usage():
return _get_meter().create_histogram(
"gen_ai.client.token.usage",
unit="{token}",
description="Token usage reported by GenAI model responses.",
)
@lru_cache(maxsize=1)
def _tool_call_duration():
return _get_meter().create_histogram(
"surfsense.tool.call.duration",
unit="ms",
description="Duration of SurfSense agent tool calls.",
)
@lru_cache(maxsize=1)
def _tool_call_errors():
return _get_meter().create_counter(
"surfsense.tool.call.errors",
description="Count of SurfSense agent tool call errors.",
)
@lru_cache(maxsize=1)
def _kb_search_duration():
return _get_meter().create_histogram(
"surfsense.kb.search.duration",
unit="ms",
description="Duration of SurfSense knowledge-base search calls.",
)
@lru_cache(maxsize=1)
def _compaction_runs():
return _get_meter().create_counter(
"surfsense.compaction.runs",
description="Count of SurfSense conversation compaction runs.",
)
@lru_cache(maxsize=1)
def _permission_asks():
return _get_meter().create_counter(
"surfsense.permission.asks",
description="Count of SurfSense permission asks.",
)
@lru_cache(maxsize=1)
def _interrupts():
return _get_meter().create_counter(
"surfsense.interrupt.raised",
description="Count of SurfSense interrupts raised.",
)
@lru_cache(maxsize=1)
def _indexing_document_duration():
return _get_meter().create_histogram(
"surfsense.indexing.document.duration",
unit="s",
description="Duration of SurfSense document indexing.",
)
@lru_cache(maxsize=1)
def _indexing_document_outcome():
return _get_meter().create_counter(
"surfsense.indexing.document.outcome",
description="Count of SurfSense document indexing outcomes.",
)
@lru_cache(maxsize=1)
def _connector_sync_duration():
return _get_meter().create_histogram(
"surfsense.connector.sync.duration",
unit="s",
description="Duration of SurfSense connector sync tasks.",
)
@lru_cache(maxsize=1)
def _connector_sync_outcome():
return _get_meter().create_counter(
"surfsense.connector.sync.outcome",
description="Count of SurfSense connector sync outcomes.",
)
@lru_cache(maxsize=1)
def _auth_failures():
return _get_meter().create_counter(
"surfsense.auth.failures",
description="Count of SurfSense authentication failures.",
)
@lru_cache(maxsize=1)
def _rate_limit_rejections():
return _get_meter().create_counter(
"surfsense.rate_limit.rejections",
description="Count of SurfSense rate-limit rejections.",
)
@lru_cache(maxsize=1)
def _perf_elapsed():
return _get_meter().create_histogram(
"surfsense.perf.elapsed_ms",
unit="ms",
description="Elapsed time recorded by SurfSense perf timers.",
)
def record_model_call_duration(
duration_ms: float, *, model: str | None, provider: str | None
) -> None:
_record(
_model_call_duration(),
duration_ms,
{
"gen_ai.request.model": model,
"gen_ai.provider.name": provider,
},
)
def record_model_token_usage(
*,
input_tokens: int | None,
output_tokens: int | None,
model: str | None,
provider: str | None,
) -> None:
base = {
"gen_ai.request.model": model,
"gen_ai.provider.name": provider,
"gen_ai.operation.name": "chat",
}
if input_tokens is not None:
_record(
_model_token_usage(),
int(input_tokens),
{**base, "gen_ai.token.type": "input"},
)
if output_tokens is not None:
_record(
_model_token_usage(),
int(output_tokens),
{**base, "gen_ai.token.type": "output"},
)
def record_tool_call_duration(duration_ms: float, *, tool_name: str) -> None:
_record(_tool_call_duration(), duration_ms, {"tool.name": tool_name})
def record_tool_call_error(*, tool_name: str) -> None:
_add(_tool_call_errors(), 1, {"tool.name": tool_name})
def record_kb_search_duration(
duration_ms: float, *, search_space_id: int | None, surface: str
) -> None:
_record(
_kb_search_duration(),
duration_ms,
{"search_space.id": search_space_id, "search.surface": surface},
)
def record_compaction_run(*, reason: str | None) -> None:
_add(_compaction_runs(), 1, {"compaction.reason": reason or "unknown"})
def record_permission_ask(*, permission: str) -> None:
_add(_permission_asks(), 1, {"permission.permission": permission})
def record_interrupt(*, interrupt_type: str) -> None:
_add(_interrupts(), 1, {"interrupt.type": interrupt_type})
def record_indexing_document_duration(
duration_s: float, *, document_type: str | None
) -> None:
_record(
_indexing_document_duration(),
duration_s,
{"document.type": document_type or "unknown"},
)
def record_indexing_document_outcome(*, document_type: str | None, status: str) -> None:
_add(
_indexing_document_outcome(),
1,
{"document.type": document_type or "unknown", "status": status},
)
def record_connector_sync_duration(
duration_s: float, *, connector_type: str | None
) -> None:
_record(
_connector_sync_duration(),
duration_s,
{"connector.type": connector_type or "unknown"},
)
def record_connector_sync_outcome(*, connector_type: str | None, status: str) -> None:
_add(
_connector_sync_outcome(),
1,
{"connector.type": connector_type or "unknown", "status": status},
)
def record_auth_failure(*, reason: str) -> None:
_add(_auth_failures(), 1, {"reason": reason})
def record_rate_limit_rejection(*, scope: str) -> None:
_add(_rate_limit_rejections(), 1, {"scope": scope})
def record_perf_elapsed(duration_ms: float, *, label: str) -> None:
_record(_perf_elapsed(), duration_ms, {"label": label})
def _runtime_snapshot_value(key: str, transform: Any = None) -> list[Any]:
from opentelemetry.metrics import Observation
from app.utils.perf import system_snapshot
snap = system_snapshot()
value = snap.get(key)
if not isinstance(value, int | float) or value < 0:
return []
if transform is not None:
value = transform(value)
return [Observation(value)]
def _observe_gc_collections(_options: Any) -> list[Any]:
from opentelemetry.metrics import Observation
return [
Observation(count, {"generation": str(generation)})
for generation, count in enumerate(gc.get_count())
]
def register_runtime_observables() -> None:
"""Register process/runtime observable gauges once per process."""
global _OBSERVABLES_REGISTERED
if _OBSERVABLES_REGISTERED or not _is_enabled():
return
meter = _get_meter()
try:
# Each callback returns the value for a single gauge except GC, whose
# callback carries a generation attribute.
meter.create_observable_gauge(
"process.runtime.cpython.memory.rss",
callbacks=[
lambda _options: _runtime_snapshot_value(
"rss_mb", lambda v: float(v) * 1024 * 1024
)
],
unit="By",
description="Resident set size of the SurfSense backend process.",
)
meter.create_observable_gauge(
"process.runtime.cpython.cpu.utilization",
callbacks=[
lambda _options: _runtime_snapshot_value(
"cpu_percent", lambda v: float(v) / 100.0
)
],
unit="1",
description="CPU utilization of the SurfSense backend process.",
)
meter.create_observable_gauge(
"process.runtime.cpython.threads",
callbacks=[lambda _options: _runtime_snapshot_value("threads")],
unit="{thread}",
description="Thread count of the SurfSense backend process.",
)
meter.create_observable_gauge(
"process.runtime.cpython.open_fds",
callbacks=[lambda _options: _runtime_snapshot_value("open_fds")],
unit="{fd}",
description="Open file descriptor count of the SurfSense backend process.",
)
meter.create_observable_gauge(
"python.asyncio.tasks",
callbacks=[lambda _options: _runtime_snapshot_value("asyncio_tasks")],
unit="{task}",
description="Live asyncio task count in the current event loop.",
)
meter.create_observable_gauge(
"process.runtime.cpython.gc.collections",
callbacks=[_observe_gc_collections],
unit="{collection}",
description="CPython GC counters by generation.",
)
except Exception:
logger.warning("Failed to register OTel runtime observables", exc_info=True)
return
_OBSERVABLES_REGISTERED = True
__all__ = [
"record_auth_failure",
"record_compaction_run",
"record_connector_sync_duration",
"record_connector_sync_outcome",
"record_indexing_document_duration",
"record_indexing_document_outcome",
"record_interrupt",
"record_kb_search_duration",
"record_model_call_duration",
"record_model_token_usage",
"record_perf_elapsed",
"record_permission_ask",
"record_rate_limit_rejection",
"record_tool_call_duration",
"record_tool_call_error",
"register_runtime_observables",
]

View file

@ -16,6 +16,8 @@ import time
from contextlib import asynccontextmanager, contextmanager
from typing import Any
from app.observability import metrics as ot_metrics
_perf_log: logging.Logger | None = None
_last_rss_mb: float = 0.0
@ -50,6 +52,7 @@ def perf_timer(label: str, *, extra: dict[str, Any] | None = None):
if extra:
suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
log.info("%s in %.3fs%s", label, elapsed, suffix)
ot_metrics.record_perf_elapsed(elapsed * 1000, label=label)
@asynccontextmanager
@ -68,6 +71,7 @@ async def perf_async_timer(label: str, *, extra: dict[str, Any] | None = None):
if extra:
suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
log.info("%s in %.3fs%s", label, elapsed, suffix)
ot_metrics.record_perf_elapsed(elapsed * 1000, label=label)
def system_snapshot() -> dict[str, Any]:

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import pytest
from app.observability import otel
from app.observability import bootstrap, metrics, otel
pytestmark = pytest.mark.unit
@ -12,7 +12,14 @@ pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _reset_otel_state(monkeypatch: pytest.MonkeyPatch):
"""Force a clean OTel disabled state per test, then restore after."""
for env in ("OTEL_EXPORTER_OTLP_ENDPOINT", "SURFSENSE_DISABLE_OTEL"):
for env in (
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_PROTOCOL",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT",
"OTEL_EXPORTER_OTLP_METRICS_ENDPOINT",
"SURFSENSE_DISABLE_OTEL",
"OTEL_SDK_DISABLED",
):
monkeypatch.delenv(env, raising=False)
monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
otel.reload_for_tests()
@ -36,6 +43,158 @@ def test_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None
assert otel.reload_for_tests() is False
def test_spec_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
monkeypatch.setenv("OTEL_SDK_DISABLED", "true")
assert otel.reload_for_tests() is False
class TestBootstrapConfig:
def test_disabled_checks_both_kill_switches(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False)
assert bootstrap.is_otel_disabled() is False
monkeypatch.setenv("OTEL_SDK_DISABLED", "on")
assert bootstrap.is_otel_disabled() is True
def test_configured_by_shared_or_signal_endpoint(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
assert bootstrap.is_otel_configured() is False
monkeypatch.setenv(
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317"
)
assert bootstrap.is_otel_configured() is True
def test_init_otel_noops_when_disabled(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
called = {"traces": False}
def fake_init_traces(app=None):
del app
called["traces"] = True
monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true")
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
monkeypatch.setattr(bootstrap, "init_traces", fake_init_traces)
bootstrap.init_otel()
assert called["traces"] is False
def test_init_otel_dispatches_enabled_signals(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
called: list[str] = []
monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False)
monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317")
monkeypatch.setattr(
bootstrap, "init_traces", lambda app=None: called.append("traces")
)
monkeypatch.setattr(bootstrap, "init_metrics", lambda: called.append("metrics"))
monkeypatch.setattr(bootstrap, "init_logs", lambda: called.append("logs"))
bootstrap.init_otel()
assert called == ["traces", "metrics", "logs"]
def test_resource_defaults_include_service_metadata(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setenv("OTEL_SERVICE_NAME", "custom-backend")
monkeypatch.setenv("SURFSENSE_ENV", "test")
resource = bootstrap._build_resource()
attrs = dict(resource.attributes)
assert attrs["service.name"] == "custom-backend"
assert attrs["deployment.environment"] == "test"
assert attrs["service.instance.id"]
def test_shutdown_is_safe_without_providers(self) -> None:
bootstrap.shutdown_otel()
class TestMetricHelpers:
def test_all_metric_helpers_noop_safely_when_disabled(self) -> None:
metrics.record_model_call_duration(12.5, model="gpt-4o", provider="openai")
metrics.record_model_token_usage(
input_tokens=10,
output_tokens=5,
model="gpt-4o",
provider="openai",
)
metrics.record_tool_call_duration(3.0, tool_name="web_search")
metrics.record_tool_call_error(tool_name="web_search")
metrics.record_kb_search_duration(
4.0,
search_space_id=1,
surface="documents",
)
metrics.record_compaction_run(reason="auto")
metrics.record_permission_ask(permission="write_file")
metrics.record_interrupt(interrupt_type="permission_ask")
metrics.record_indexing_document_duration(1.2, document_type="FILE")
metrics.record_indexing_document_outcome(document_type="FILE", status="success")
metrics.record_connector_sync_duration(
2.3,
connector_type="index_notion_pages",
)
metrics.record_connector_sync_outcome(
connector_type="index_notion_pages",
status="success",
)
metrics.record_auth_failure(reason="UNAUTHORIZED")
metrics.record_rate_limit_rejection(scope="login")
metrics.record_perf_elapsed(7.0, label="[test]")
def test_runtime_observables_register_once(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
class FakeMeter:
def __init__(self) -> None:
self.names: list[str] = []
def create_observable_gauge(self, name: str, **kwargs) -> None:
del kwargs
self.names.append(name)
fake_meter = FakeMeter()
monkeypatch.setattr(metrics, "_OBSERVABLES_REGISTERED", False)
monkeypatch.setattr(metrics, "_is_enabled", lambda: True)
monkeypatch.setattr(metrics, "_get_meter", lambda: fake_meter)
metrics.register_runtime_observables()
metrics.register_runtime_observables()
assert len(fake_meter.names) == 6
assert fake_meter.names.count("python.asyncio.tasks") == 1
monkeypatch.setattr(metrics, "_OBSERVABLES_REGISTERED", False)
def test_log_record_factory_provides_zero_otel_fields() -> None:
import logging
import main # noqa: F401
record = logging.getLogRecordFactory()(
"test",
logging.INFO,
__file__,
1,
"hello",
(),
None,
)
assert record.otelTraceID == "0"
assert record.otelSpanID == "0"
class TestNoopSpansWhenDisabled:
def test_generic_span_yields_noop(self) -> None:
with otel.span("any.thing", attributes={"x": 1}) as sp: