feat(observability): add OpenTelemetry process bootstrap

This commit is contained in:
Anish Sarkar 2026-05-21 23:01:54 +05:30
parent 60049936e3
commit eb2e2b253b
7 changed files with 413 additions and 21 deletions

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
import contextlib
import gc import gc
import logging import logging
import time import time
@ -36,13 +37,15 @@ from app.config import (
) )
from app.db import User, create_db_and_tables, get_async_session from app.db import User, create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
from app.observability import metrics as ot_metrics
from app.observability.bootstrap import init_otel, shutdown_otel
from app.rate_limiter import get_real_client_ip, limiter from app.rate_limiter import get_real_client_ip, limiter
from app.routes import router as crud_router from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.utils.perf import get_perf_logger, log_system_snapshot from app.utils.perf import log_system_snapshot
_error_logger = logging.getLogger("surfsense.errors") _error_logger = logging.getLogger("surfsense.errors")
@ -127,6 +130,8 @@ def _http_exception_handler(request: Request, exc: HTTPException) -> JSONRespons
logged server-side. logged server-side.
""" """
rid = _get_request_id(request) rid = _get_request_id(request)
if exc.status_code in {401, 403} and request.url.path.startswith("/auth"):
ot_metrics.record_auth_failure(reason=_status_to_code(exc.status_code))
should_sanitize = exc.status_code == 500 should_sanitize = exc.status_code == 500
# Structured dict details (e.g. {"code": "CAPTCHA_REQUIRED", "message": "..."}) # Structured dict details (e.g. {"code": "CAPTCHA_REQUIRED", "message": "..."})
@ -213,6 +218,7 @@ def _validation_error_handler(
def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Catch-all: log full traceback, return sanitized 500.""" """Catch-all: log full traceback, return sanitized 500."""
rid = _get_request_id(request) rid = _get_request_id(request)
ot_metrics.record_auth_failure(reason="unhandled_exception")
_error_logger.error( _error_logger.error(
"[%s] Unhandled exception on %s %s", "[%s] Unhandled exception on %s %s",
rid, rid,
@ -246,6 +252,7 @@ def _status_to_code(status_code: int, detail: str = "") -> str:
def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded): def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
"""Custom 429 handler that returns JSON matching our error envelope.""" """Custom 429 handler that returns JSON matching our error envelope."""
rid = _get_request_id(request) rid = _get_request_id(request)
ot_metrics.record_rate_limit_rejection(scope="slowapi")
retry_after = exc.detail.split("per")[-1].strip() if exc.detail else "60" retry_after = exc.detail.split("per")[-1].strip() if exc.detail else "60"
return _build_error_response( return _build_error_response(
429, 429,
@ -306,6 +313,7 @@ def _check_rate_limit_memory(
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} " f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
f"({len(timestamps)}/{max_requests} in {window_seconds}s)" f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
) )
ot_metrics.record_rate_limit_rejection(scope=scope)
raise HTTPException( raise HTTPException(
status_code=429, status_code=429,
detail="RATE_LIMIT_EXCEEDED", detail="RATE_LIMIT_EXCEEDED",
@ -349,6 +357,7 @@ def _check_rate_limit(
f"Rate limit exceeded on {scope} for IP {client_ip} " f"Rate limit exceeded on {scope} for IP {client_ip} "
f"({current_count}/{max_requests} in {window_seconds}s)" f"({current_count}/{max_requests} in {window_seconds}s)"
) )
ot_metrics.record_rate_limit_rejection(scope=scope)
raise HTTPException( raise HTTPException(
status_code=429, status_code=429,
detail="RATE_LIMIT_EXCEEDED", detail="RATE_LIMIT_EXCEEDED",
@ -558,6 +567,7 @@ async def lifespan(app: FastAPI):
gc.set_threshold(700, 10, 5) gc.set_threshold(700, 10, 5)
_enable_slow_callback_logging(threshold_sec=0.5) _enable_slow_callback_logging(threshold_sec=0.5)
init_otel(app)
await create_db_and_tables() await create_db_and_tables()
await setup_checkpointer_tables() await setup_checkpointer_tables()
initialize_openrouter_integration() initialize_openrouter_integration()
@ -592,6 +602,7 @@ async def lifespan(app: FastAPI):
_stop_openrouter_background_refresh() _stop_openrouter_background_refresh()
await close_checkpointer() await close_checkpointer()
shutdown_otel()
def registration_allowed(): def registration_allowed():
@ -676,32 +687,20 @@ class RequestPerfMiddleware(BaseHTTPMiddleware):
async def dispatch( async def dispatch(
self, request: StarletteRequest, call_next: RequestResponseEndpoint self, request: StarletteRequest, call_next: RequestResponseEndpoint
) -> StarletteResponse: ) -> StarletteResponse:
perf = get_perf_logger()
t0 = time.perf_counter() t0 = time.perf_counter()
response = await call_next(request) response = await call_next(request)
elapsed_ms = (time.perf_counter() - t0) * 1000 elapsed_ms = (time.perf_counter() - t0) * 1000
path = request.url.path path = request.url.path
method = request.method
status = response.status_code
perf.debug(
"[request] %s %s -> %d in %.1fms",
method,
path,
status,
elapsed_ms,
)
if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD: if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD:
perf.warning( with contextlib.suppress(Exception):
"[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)", from opentelemetry import trace
method,
path, span = trace.get_current_span()
status, span.set_attribute("slow_request", True)
elapsed_ms, span.set_attribute("surfsense.request.elapsed_ms", elapsed_ms)
_PERF_SLOW_REQUEST_THRESHOLD, span.set_attribute("http.route", path)
)
log_system_snapshot("slow_request") log_system_snapshot("slow_request")
return response return response

View file

@ -18,6 +18,10 @@ def init_worker(**kwargs):
This ensures the Auto mode (LiteLLM Router) is available for background tasks This ensures the Auto mode (LiteLLM Router) is available for background tasks
like document summarization and image generation. like document summarization and image generation.
""" """
from app.observability.bootstrap import init_otel
init_otel(app=None, traces=True, metrics=True, logs=True)
from app.config import ( from app.config import (
initialize_image_gen_router, initialize_image_gen_router,
initialize_llm_router, initialize_llm_router,

View file

@ -5,3 +5,5 @@ small wrapper around the optional ``opentelemetry`` instrumentation. The
wrapper is a no-op when OTEL is not configured, so importing it from wrapper is a no-op when OTEL is not configured, so importing it from
performance-critical paths is safe. performance-critical paths is safe.
""" """
__all__ = ["bootstrap", "metrics", "otel"]

View file

@ -0,0 +1,361 @@
"""Programmatic OpenTelemetry bootstrap for SurfSense backend processes."""
from __future__ import annotations
import contextlib
import logging
import os
import socket
from importlib import metadata
from typing import Any
from app.observability import otel
# Module-level logger used for all bootstrap diagnostics in this file.
logger = logging.getLogger(__name__)
# Lower-cased values that _env_truthy accepts as "enabled" for a flag variable.
_BOOL_TRUE = {"1", "true", "yes", "on"}
# Per-process guards. Each flag is flipped to True by the matching init_* /
# instrument_* function after it succeeds, so repeated calls become no-ops.
_TRACES_INITIALIZED = False
_METRICS_INITIALIZED = False
_LOGS_INITIALIZED = False
_FASTAPI_INSTRUMENTED = False
_SQLALCHEMY_INSTRUMENTED = False
_PSYCOPG_INSTRUMENTED = False
_REDIS_INSTRUMENTED = False
_HTTPX_INSTRUMENTED = False
_CELERY_INSTRUMENTED = False
# Providers recorded by init_traces / init_metrics; shutdown_otel flushes and
# shuts down whichever of these is not None.
_TRACER_PROVIDER: Any | None = None
_METER_PROVIDER: Any | None = None
def _env_truthy(name: str) -> bool:
    """Report whether environment variable *name* holds a truthy flag value.

    The value is stripped and lower-cased before being looked up in
    ``_BOOL_TRUE``; an unset variable is treated as the empty string and
    therefore counts as false.
    """
    raw_value = os.environ.get(name, "")
    normalized = raw_value.strip().lower()
    return normalized in _BOOL_TRUE
def is_otel_disabled() -> bool:
    """Tell whether an explicit kill switch turns OpenTelemetry off.

    Either the SurfSense-specific ``SURFSENSE_DISABLE_OTEL`` variable or the
    ``OTEL_SDK_DISABLED`` variable being truthy is enough.
    """
    kill_switches = ("SURFSENSE_DISABLE_OTEL", "OTEL_SDK_DISABLED")
    return any(_env_truthy(switch) for switch in kill_switches)
def is_otel_configured() -> bool:
    """Tell whether this process has an OTLP endpoint to export to.

    Any one of the generic, traces-specific, or metrics-specific OTLP
    endpoint variables being set to a non-empty value counts as configured.
    """
    endpoint_vars = (
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT",
        "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT",
    )
    return any(os.environ.get(var) for var in endpoint_vars)
def _package_version() -> str:
    """Look up the installed ``surf-new-backend`` distribution version.

    Returns:
        The version string from package metadata, or ``"unknown"`` when the
        distribution is not installed.
    """
    try:
        return metadata.version("surf-new-backend")
    except metadata.PackageNotFoundError:
        return "unknown"
def _deployment_environment() -> str:
    """Pick the deployment environment label from the process environment.

    ``SURFSENSE_ENV``, ``APP_ENV`` and ``ENVIRONMENT`` are consulted in that
    order; the first non-empty value wins and ``"dev"`` is the fallback.
    """
    for var in ("SURFSENSE_ENV", "APP_ENV", "ENVIRONMENT"):
        value = os.environ.get(var)
        if value:
            return value
    return "dev"
def _build_resource():
    """Build the OpenTelemetry ``Resource`` describing this process.

    The service name comes from ``OTEL_SERVICE_NAME`` (defaulting to
    ``"surfsense-backend"``); version, host name and deployment environment
    are filled in from the helpers in this module and ``socket.gethostname``.
    """
    from opentelemetry.sdk.resources import Resource

    attributes = {
        "service.name": os.environ.get("OTEL_SERVICE_NAME", "surfsense-backend"),
        "service.version": _package_version(),
        "service.instance.id": socket.gethostname(),
        "deployment.environment": _deployment_environment(),
    }
    return Resource.create(attributes)
def _otlp_protocol() -> str:
    """Return the configured OTLP transport protocol, normalized.

    Reads ``OTEL_EXPORTER_OTLP_PROTOCOL`` (default ``"grpc"``) and returns it
    stripped of surrounding whitespace and lower-cased.
    """
    configured = os.environ.get("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc")
    return configured.strip().lower()
def _trace_exporter():
    """Create the OTLP span exporter matching the configured protocol.

    The HTTP exporter is used when the protocol is ``"http/protobuf"``; any
    other value selects the gRPC exporter. When
    ``OTEL_EXPORTER_OTLP_TRACES_ENDPOINT`` is set it is passed explicitly,
    otherwise the exporter is built with its own defaults.
    """
    # Import lazily so only the exporter package for the chosen protocol is
    # required at runtime.
    if _otlp_protocol() == "http/protobuf":
        from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
            OTLPSpanExporter,
        )
    else:
        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
            OTLPSpanExporter,
        )

    traces_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")
    if traces_endpoint:
        return OTLPSpanExporter(endpoint=traces_endpoint)
    return OTLPSpanExporter()
def _metric_exporter():
    """Create the OTLP metric exporter matching the configured protocol.

    The HTTP exporter is used when the protocol is ``"http/protobuf"``; any
    other value selects the gRPC exporter. When
    ``OTEL_EXPORTER_OTLP_METRICS_ENDPOINT`` is set it is passed explicitly,
    otherwise the exporter is built with its own defaults.
    """
    # Import lazily so only the exporter package for the chosen protocol is
    # required at runtime.
    if _otlp_protocol() == "http/protobuf":
        from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
            OTLPMetricExporter,
        )
    else:
        from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
            OTLPMetricExporter,
        )

    metrics_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT")
    if metrics_endpoint:
        return OTLPMetricExporter(endpoint=metrics_endpoint)
    return OTLPMetricExporter()
def _safe_instrument(name: str, instrument: Any) -> bool:
    """Run an instrumentation callable without letting it raise.

    Args:
        name: Human-readable library name used in the warning message.
        instrument: Zero-argument callable that performs the instrumentation.

    Returns:
        ``True`` when *instrument* completed, ``False`` when it raised (the
        failure is logged as a warning with the traceback).
    """
    succeeded = True
    try:
        instrument()
    except Exception:
        succeeded = False
        logger.warning("OpenTelemetry %s instrumentation failed", name, exc_info=True)
    return succeeded
def _instrument_fastapi(app: Any | None) -> None:
    """Attach the FastAPI instrumentor to *app* at most once per process.

    Does nothing when *app* is ``None`` or when a previous call already
    succeeded. Requests to ``/health``, ``/ready`` and ``/metrics`` are
    excluded from instrumentation.
    """
    global _FASTAPI_INSTRUMENTED
    if _FASTAPI_INSTRUMENTED or app is None:
        return

    def _attach() -> None:
        from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

        FastAPIInstrumentor.instrument_app(
            app,
            excluded_urls="/health,/ready,/metrics",
        )

    if not _safe_instrument("FastAPI", _attach):
        return
    _FASTAPI_INSTRUMENTED = True
def instrument_sqlalchemy_engine(engine: Any) -> None:
    """Instrument a SQLAlchemy engine, at most once per process.

    Async engines are unwrapped to their ``sync_engine`` before being handed
    to the instrumentor. Because the guard is a single process-wide flag,
    only the first engine that is instrumented successfully is covered by
    this function; later calls return immediately.
    """
    global _SQLALCHEMY_INSTRUMENTED
    if _SQLALCHEMY_INSTRUMENTED:
        return

    def _attach() -> None:
        from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor

        target_engine = getattr(engine, "sync_engine", engine)
        SQLAlchemyInstrumentor().instrument(
            engine=target_engine,
            enable_commenter=True,
        )

    if not _safe_instrument("SQLAlchemy", _attach):
        return
    _SQLALCHEMY_INSTRUMENTED = True
def _instrument_sqlalchemy() -> None:
    """Instrument the application's default SQLAlchemy engine, best-effort.

    The engine is imported lazily from ``app.db``. Any failure (for example
    the import itself raising) is swallowed so telemetry setup never breaks
    process startup, but it is recorded at debug level so a missing set of
    SQLAlchemy spans can be diagnosed.
    """
    if _SQLALCHEMY_INSTRUMENTED:
        return
    try:
        from app.db import engine

        instrument_sqlalchemy_engine(engine)
    except Exception:
        # Still best-effort: log instead of silently discarding the reason.
        logger.debug(
            "OpenTelemetry SQLAlchemy instrumentation skipped",
            exc_info=True,
        )
def _instrument_psycopg() -> None:
    """Enable the psycopg instrumentor, at most once per process."""
    global _PSYCOPG_INSTRUMENTED
    if _PSYCOPG_INSTRUMENTED:
        return

    def _attach() -> None:
        from opentelemetry.instrumentation.psycopg import PsycopgInstrumentor

        PsycopgInstrumentor().instrument()

    if not _safe_instrument("psycopg", _attach):
        return
    _PSYCOPG_INSTRUMENTED = True
def _instrument_redis() -> None:
    """Enable the Redis instrumentor, at most once per process."""
    global _REDIS_INSTRUMENTED
    if _REDIS_INSTRUMENTED:
        return

    def _attach() -> None:
        from opentelemetry.instrumentation.redis import RedisInstrumentor

        RedisInstrumentor().instrument()

    if not _safe_instrument("Redis", _attach):
        return
    _REDIS_INSTRUMENTED = True
def _instrument_httpx() -> None:
    """Enable the HTTPX client instrumentor, at most once per process."""
    global _HTTPX_INSTRUMENTED
    if _HTTPX_INSTRUMENTED:
        return

    def _attach() -> None:
        from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor

        HTTPXClientInstrumentor().instrument()

    if not _safe_instrument("HTTPX", _attach):
        return
    _HTTPX_INSTRUMENTED = True
def instrument_celery() -> None:
    """Enable the Celery instrumentor, at most once per process.

    A failed attempt leaves the guard unset, so a later call will retry.
    """
    global _CELERY_INSTRUMENTED
    if _CELERY_INSTRUMENTED:
        return

    def _attach() -> None:
        from opentelemetry.instrumentation.celery import CeleryInstrumentor

        CeleryInstrumentor().instrument()

    if not _safe_instrument("Celery", _attach):
        return
    _CELERY_INSTRUMENTED = True
def _instrument_libraries(app: Any | None) -> None:
    """Run every library instrumentor in a fixed order.

    FastAPI goes first because it is the only one that needs *app*; the
    remaining instrumentors take no arguments.
    """
    _instrument_fastapi(app)
    argless_instrumentors = (
        _instrument_sqlalchemy,
        _instrument_psycopg,
        _instrument_redis,
        _instrument_httpx,
        instrument_celery,
    )
    for run_instrumentor in argless_instrumentors:
        run_instrumentor()
def init_traces(app: Any | None = None) -> None:
    """Install the tracer provider, span processor, exporter, and instrumentors.

    Args:
        app: Optional FastAPI application to instrument. ``None`` skips the
            FastAPI instrumentor (all other library instrumentors still run).

    Once tracing has been initialized in this process, later calls only
    attempt FastAPI instrumentation for *app* and return.
    """
    global _TRACER_PROVIDER, _TRACES_INITIALIZED
    if _TRACES_INITIALIZED:
        _instrument_fastapi(app)
        return

    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
    from opentelemetry.sdk.trace.sampling import ALWAYS_ON, ParentBased

    provider = TracerProvider(
        resource=_build_resource(),
        sampler=ParentBased(ALWAYS_ON),
    )
    provider.add_span_processor(BatchSpanProcessor(_trace_exporter()))

    set_error: Exception | None = None
    try:
        trace.set_tracer_provider(provider)
    except Exception as exc:
        set_error = exc

    # The OpenTelemetry API does not raise when a global provider is already
    # installed; it ignores the call. Compare identities to learn whether
    # our provider actually became the global one.
    installed = trace.get_tracer_provider()
    if installed is provider:
        _TRACER_PROVIDER = provider
    else:
        logger.warning(
            "OpenTelemetry tracer provider was already set; reusing existing provider",
            exc_info=set_error,
        )
        # Our provider is not in use: stop its batch processor so it does not
        # linger for the life of the process.
        with contextlib.suppress(Exception):
            provider.shutdown()
        _TRACER_PROVIDER = installed

    _TRACES_INITIALIZED = True
    otel.reload_for_tests()
    _instrument_libraries(app)
_DEFAULT_METRIC_EXPORT_INTERVAL_MS = 60000


def _metric_export_interval_ms() -> int:
    """Read ``OTEL_METRIC_EXPORT_INTERVAL`` as a positive millisecond count.

    Falls back to the 60000 ms default (with a warning) when the variable is
    malformed or not greater than zero, so a bad value cannot abort startup.
    """
    raw = os.environ.get("OTEL_METRIC_EXPORT_INTERVAL", "")
    if not raw.strip():
        return _DEFAULT_METRIC_EXPORT_INTERVAL_MS
    try:
        interval_ms = int(raw)
    except ValueError:
        interval_ms = 0
    if interval_ms <= 0:
        logger.warning(
            "Invalid OTEL_METRIC_EXPORT_INTERVAL %r; using %d ms",
            raw,
            _DEFAULT_METRIC_EXPORT_INTERVAL_MS,
        )
        return _DEFAULT_METRIC_EXPORT_INTERVAL_MS
    return interval_ms


def init_metrics() -> None:
    """Install the meter provider, metric reader, exporter, and custom gauges.

    Idempotent: returns immediately once metrics have been initialized in
    this process. After the provider is in place, the runtime observables
    from ``app.observability.metrics`` are registered.
    """
    global _METER_PROVIDER, _METRICS_INITIALIZED
    if _METRICS_INITIALIZED:
        return

    from opentelemetry import metrics
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader

    reader = PeriodicExportingMetricReader(
        _metric_exporter(),
        export_interval_millis=_metric_export_interval_ms(),
    )
    provider = MeterProvider(metric_readers=[reader], resource=_build_resource())

    set_error: Exception | None = None
    try:
        metrics.set_meter_provider(provider)
    except Exception as exc:
        set_error = exc

    # The OpenTelemetry API does not raise when a global provider is already
    # installed; it ignores the call. Compare identities to learn whether
    # our provider actually became the global one.
    installed = metrics.get_meter_provider()
    if installed is provider:
        _METER_PROVIDER = provider
    else:
        logger.warning(
            "OpenTelemetry meter provider was already set; reusing existing provider",
            exc_info=set_error,
        )
        # Our provider is not in use: stop its periodic reader so it does not
        # keep exporting in the background.
        with contextlib.suppress(Exception):
            provider.shutdown()
        _METER_PROVIDER = installed

    _METRICS_INITIALIZED = True

    from app.observability.metrics import register_runtime_observables

    register_runtime_observables()
def init_logs() -> None:
    """Turn on the stdlib logging instrumentor, at most once per process.

    The instrumentor adds trace/span correlation fields to ``LogRecord``
    objects. A failed attempt leaves the guard unset so it can be retried.
    """
    global _LOGS_INITIALIZED
    if _LOGS_INITIALIZED:
        return

    def _attach() -> None:
        from opentelemetry.instrumentation.logging import LoggingInstrumentor

        LoggingInstrumentor().instrument()

    if not _safe_instrument("logging", _attach):
        return
    _LOGS_INITIALIZED = True
def init_otel(
    app: Any | None = None,
    *,
    traces: bool = True,
    metrics: bool = True,
    logs: bool = True,
) -> None:
    """Initialize OpenTelemetry for a FastAPI or Celery process.

    Args:
        app: Optional FastAPI application forwarded to ``init_traces``.
        traces: Whether to set up tracing.
        metrics: Whether to set up metrics.
        logs: Whether to set up log correlation.

    When a kill switch is set or no OTLP endpoint is configured, nothing is
    installed; only ``otel.reload_for_tests()`` is called before returning.
    """
    should_export = not is_otel_disabled() and is_otel_configured()
    if not should_export:
        otel.reload_for_tests()
        return

    if traces:
        init_traces(app)
    if metrics:
        init_metrics()
    if logs:
        init_logs()
def shutdown_otel(timeout_millis: int = 5000) -> None:
    """Flush and shut down whichever providers this module recorded.

    Args:
        timeout_millis: Timeout passed to each provider's ``force_flush``.

    Both steps are best-effort: an exception from flushing or shutting down
    one provider is suppressed and does not prevent the remaining work.
    """
    recorded = [
        candidate
        for candidate in (_TRACER_PROVIDER, _METER_PROVIDER)
        if candidate is not None
    ]
    for provider in recorded:
        # Flush and shutdown are suppressed separately so a failed flush
        # still lets shutdown run.
        with contextlib.suppress(Exception):
            provider.force_flush(timeout_millis=timeout_millis)
        with contextlib.suppress(Exception):
            provider.shutdown()
# Names exported by ``from app.observability.bootstrap import *``.
# NOTE(review): two underscore-prefixed names (_BOOL_TRUE, _build_resource)
# are exported here — presumably for tests or sibling modules; confirm.
__all__ = [
    "_BOOL_TRUE",
    "_build_resource",
    "init_logs",
    "init_metrics",
    "init_otel",
    "init_traces",
    "instrument_celery",
    "instrument_sqlalchemy_engine",
    "is_otel_configured",
    "is_otel_disabled",
    "shutdown_otel",
]

View file

@ -66,6 +66,8 @@ def _resolve_enabled() -> bool:
# Honor an explicit kill-switch first. # Honor an explicit kill-switch first.
if os.environ.get("SURFSENSE_DISABLE_OTEL", "").lower() in {"1", "true", "yes"}: if os.environ.get("SURFSENSE_DISABLE_OTEL", "").lower() in {"1", "true", "yes"}:
return False return False
if os.environ.get("OTEL_SDK_DISABLED", "").lower() in {"1", "true", "yes", "on"}:
return False
# Treat a configured endpoint as the canonical "OTel is wired up" signal. # Treat a configured endpoint as the canonical "OTel is wired up" signal.
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"): if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
return True return True
@ -198,8 +200,11 @@ def model_call_span(
attrs: dict[str, Any] = {} attrs: dict[str, Any] = {}
if model_id: if model_id:
attrs["model.id"] = model_id attrs["model.id"] = model_id
attrs["gen_ai.request.model"] = model_id
if provider: if provider:
attrs["model.provider"] = provider attrs["model.provider"] = provider
attrs["gen_ai.provider.name"] = provider
attrs["gen_ai.operation.name"] = "chat"
if extra: if extra:
attrs.update(extra) attrs.update(extra)
return span("model.call", attributes=attrs) return span("model.call", attributes=attrs)

View file

@ -37,6 +37,10 @@ def get_celery_session_maker() -> async_sessionmaker:
poolclass=NullPool, poolclass=NullPool,
echo=False, echo=False,
) )
with contextlib.suppress(Exception):
from app.observability.bootstrap import instrument_sqlalchemy_engine
instrument_sqlalchemy_engine(_celery_engine)
_celery_session_maker = async_sessionmaker( _celery_session_maker = async_sessionmaker(
_celery_engine, expire_on_commit=False _celery_engine, expire_on_commit=False
) )

View file

@ -12,9 +12,26 @@ if sys.platform == "win32":
from app.config.uvicorn import load_uvicorn_config from app.config.uvicorn import load_uvicorn_config
_old_log_record_factory = logging.getLogRecordFactory()
def _otel_safe_log_record_factory(*args, **kwargs):
record = _old_log_record_factory(*args, **kwargs)
if not hasattr(record, "otelTraceID"):
record.otelTraceID = "0"
if not hasattr(record, "otelSpanID"):
record.otelSpanID = "0"
return record
logging.setLogRecordFactory(_otel_safe_log_record_factory)
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format=(
"%(asctime)s - %(name)s - %(levelname)s - "
"[trace_id=%(otelTraceID)s span_id=%(otelSpanID)s] %(message)s"
),
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
) )