plano/cli/planoai/obs/collector.py
2026-04-17 14:03:47 -07:00

266 lines
8.4 KiB
Python

"""In-memory collector for LLM calls, fed by OTLP/gRPC spans from brightstaff."""
from __future__ import annotations
import threading
from collections import deque
from concurrent import futures
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Iterable
import grpc
from opentelemetry.proto.collector.trace.v1 import (
trace_service_pb2,
trace_service_pb2_grpc,
)
# Standard OTLP/gRPC listening port used when no port is given.
DEFAULT_GRPC_PORT: int = 4317
# Default number of LLM calls LLMCallStore keeps before evicting the oldest.
DEFAULT_CAPACITY: int = 1_000
@dataclass
class LLMCall:
    """A single LLM call rebuilt from one brightstaff LLM span.

    Every optional field stays ``None`` when the span did not carry the
    matching attribute.
    """

    request_id: str
    timestamp: datetime
    model: str
    provider: str | None = None
    request_model: str | None = None
    session_id: str | None = None
    route_name: str | None = None
    is_streaming: bool | None = None
    status_code: int | None = None
    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None
    cached_input_tokens: int | None = None
    cache_creation_tokens: int | None = None
    reasoning_tokens: int | None = None
    ttft_ms: float | None = None
    duration_ms: float | None = None
    routing_strategy: str | None = None
    routing_reason: str | None = None
    cost_usd: float | None = None

    @property
    def tpt_ms(self) -> float | None:
        """Generation time per completion token, in milliseconds.

        The generation window is ``duration_ms`` minus ``ttft_ms``, with a
        missing or zero TTFT counting as 0. Returns ``None`` in three cases:
        the duration is unknown, the completion-token count is ``None`` or 0,
        or the window is not strictly positive.
        """
        tokens = self.completion_tokens
        if self.duration_ms is None or tokens is None or tokens == 0:
            return None
        window = self.duration_ms - (self.ttft_ms or 0.0)
        # Written as "not > 0" so a NaN window is rejected too.
        if not window > 0:
            return None
        return window / tokens

    @property
    def tokens_per_sec(self) -> float | None:
        """Completion throughput (``1000 / tpt_ms``).

        Returns ``None`` when ``tpt_ms`` is unavailable or not positive.
        """
        per_token_ms = self.tpt_ms
        if per_token_ms is not None and per_token_ms > 0:
            return 1000.0 / per_token_ms
        return None
class LLMCallStore:
    """Bounded, lock-guarded buffer of the most recent ``LLMCall`` records.

    Backed by ``deque(maxlen=capacity)``, so once the buffer is full each
    ``add`` evicts the oldest entry. Every operation on the buffer runs while
    holding one ``threading.Lock``.
    """

    def __init__(self, capacity: int = DEFAULT_CAPACITY) -> None:
        self._capacity = capacity
        self._lock = threading.Lock()
        self._calls: deque[LLMCall] = deque(maxlen=capacity)

    @property
    def capacity(self) -> int:
        """Maximum number of calls retained."""
        return self._capacity

    def add(self, call: LLMCall) -> None:
        """Append *call*, evicting the oldest entry if the buffer is full."""
        with self._lock:
            self._calls.append(call)

    def clear(self) -> None:
        """Drop every stored call."""
        with self._lock:
            self._calls.clear()

    def snapshot(self) -> list[LLMCall]:
        """Return a point-in-time copy of the stored calls, oldest first."""
        with self._lock:
            copied = [*self._calls]
        return copied

    def __len__(self) -> int:
        with self._lock:
            size = len(self._calls)
        return size
# Span attribute keys used below are the canonical OTel / Plano keys emitted by
# brightstaff — see crates/brightstaff/src/tracing/constants.rs for the source
# of truth.
def _anyvalue_to_python(value: Any) -> Any:  # AnyValue from OTLP
    """Unwrap the scalar held by an OTLP ``AnyValue`` message.

    Returns the string, bool, int or double payload. Any other oneof case,
    or an unset value, yields ``None``.
    """
    kind = value.WhichOneof("value")
    if kind in ("string_value", "bool_value", "int_value", "double_value"):
        return getattr(value, kind)
    return None
def _attrs_to_dict(attrs: Iterable[Any]) -> dict[str, Any]:
    """Flatten OTLP ``KeyValue`` pairs into ``{key: python_value}``.

    A pair is omitted when ``_anyvalue_to_python`` cannot represent its value
    (it returns ``None``). If a key repeats, its last representable value wins.
    """
    pairs = ((kv.key, _anyvalue_to_python(kv.value)) for kv in attrs)
    return {key: py for key, py in pairs if py is not None}
def _maybe_int(value: Any) -> int | None:
    """Best-effort conversion of a span attribute value to ``int``.

    Returns ``None`` when *value* is ``None`` or cannot be converted. This
    covers non-finite floats: ``int(float("inf"))`` raises ``OverflowError``,
    not ``ValueError``. Floats are truncated toward zero, as ``int()`` does.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError comes from int(inf); ValueError from int(nan) or
        # non-numeric strings; TypeError from unsupported types.
        return None
def _maybe_float(value: Any) -> float | None:
    """Best-effort conversion of a span attribute value to ``float``.

    Returns ``None`` when *value* is ``None`` or cannot be converted. This
    includes ints too large for a float, for which ``float()`` raises
    ``OverflowError``.
    """
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return None
def _str_attr(attrs: dict[str, Any], key: str) -> str | None:
    """Return ``str(attrs[key])``, or ``None`` when *key* is absent or ``None``."""
    value = attrs.get(key)
    return None if value is None else str(value)


def _bool_attr(attrs: dict[str, Any], key: str) -> bool | None:
    """Interpret ``attrs[key]`` as a boolean flag.

    Non-string values go through ``bool()``. Strings are parsed explicitly,
    because ``bool("false")`` is ``True``:

    - ``"true"``, ``"1"`` and ``"yes"`` map to ``True``.
    - ``"false"``, ``"0"``, ``"no"`` and ``""`` map to ``False``.
    - Matching is case-insensitive and ignores surrounding whitespace.

    Any other string, or a missing key, yields ``None``.
    """
    value = attrs.get(key)
    if value is None:
        return None
    if isinstance(value, str):
        flag = value.strip().lower()
        if flag in ("true", "1", "yes"):
            return True
        if flag in ("false", "0", "no", ""):
            return False
        return None
    return bool(value)


def span_to_llm_call(
    span: Any, service_name: str, pricing: Any | None = None
) -> LLMCall | None:
    """Convert an OTLP span into an LLMCall, or return None if it isn't one.

    A span counts as an LLM call only if it carries a truthy ``llm.model``
    attribute.

    Args:
        span: OTLP span message. It must expose ``attributes``, ``span_id``
            (bytes) and ``start_time_unix_nano``.
        service_name: Resource ``service.name``. It is used as ``provider``
            when the span has no ``llm.provider`` attribute.
        pricing: Optional object with ``cost_for_call(call)``. When given,
            its result is stored in ``cost_usd``.

    Returns:
        The reconstructed ``LLMCall``, or ``None`` for non-LLM spans.
    """
    attrs = _attrs_to_dict(span.attributes)
    model = attrs.get("llm.model")
    if not model:
        return None

    # Prefer an explicit, non-empty request-id attribute. Otherwise fall back
    # to the hex span id, or "" when the span has none.
    request_id = next(
        (
            str(attrs[key])
            for key in ("request_id", "http.request_id")
            if key in attrs and str(attrs[key])
        ),
        span.span_id.hex() if span.span_id else "",
    )

    # A zero/missing start time falls back to the current local time.
    start_ns = span.start_time_unix_nano or 0
    if start_ns:
        ts = datetime.fromtimestamp(
            start_ns / 1_000_000_000, tz=timezone.utc
        ).astimezone()
    else:
        ts = datetime.now().astimezone()

    provider = _str_attr(attrs, "llm.provider")
    call = LLMCall(
        request_id=request_id,
        timestamp=ts,
        model=str(model),
        provider=provider if provider is not None else service_name,
        request_model=_str_attr(attrs, "model.requested"),
        session_id=_str_attr(attrs, "plano.session_id"),
        route_name=_str_attr(attrs, "plano.route.name"),
        # Parsed explicitly so a string "false" flag is not read as True.
        is_streaming=_bool_attr(attrs, "llm.is_streaming"),
        status_code=_maybe_int(attrs.get("http.status_code")),
        prompt_tokens=_maybe_int(attrs.get("llm.usage.prompt_tokens")),
        completion_tokens=_maybe_int(attrs.get("llm.usage.completion_tokens")),
        total_tokens=_maybe_int(attrs.get("llm.usage.total_tokens")),
        cached_input_tokens=_maybe_int(attrs.get("llm.usage.cached_input_tokens")),
        cache_creation_tokens=_maybe_int(attrs.get("llm.usage.cache_creation_tokens")),
        reasoning_tokens=_maybe_int(attrs.get("llm.usage.reasoning_tokens")),
        ttft_ms=_maybe_float(attrs.get("llm.time_to_first_token")),
        duration_ms=_maybe_float(attrs.get("llm.duration_ms")),
        routing_strategy=_str_attr(attrs, "routing.strategy"),
        routing_reason=_str_attr(attrs, "routing.selection_reason"),
    )
    if pricing is not None:
        call.cost_usd = pricing.cost_for_call(call)
    return call
class _ObsServicer(trace_service_pb2_grpc.TraceServiceServicer):
    """OTLP TraceService handler that records LLM spans into an LLMCallStore."""

    def __init__(self, store: LLMCallStore, pricing: Any | None) -> None:
        self._pricing = pricing
        self._store = store

    @staticmethod
    def _service_name(resource_spans: Any) -> str:
        """Return the first non-null ``service.name`` resource attribute, else ``"unknown"``."""
        for attr in resource_spans.resource.attributes:
            if attr.key != "service.name":
                continue
            val = _anyvalue_to_python(attr.value)
            if val is not None:
                return str(val)
        return "unknown"

    def Export(self, request, context):  # noqa: N802 — gRPC generated name
        """Convert every span in *request* and store those that are LLM calls."""
        for resource_spans in request.resource_spans:
            service_name = self._service_name(resource_spans)
            spans = (
                span
                for scope_spans in resource_spans.scope_spans
                for span in scope_spans.spans
            )
            for span in spans:
                call = span_to_llm_call(span, service_name, self._pricing)
                if call is not None:
                    self._store.add(call)
        return trace_service_pb2.ExportTraceServiceResponse()
@dataclass
class ObsCollector:
    """Owns the OTLP/gRPC server and the in-memory LLMCall ring buffer.

    ``start()`` binds ``host:port`` and serves the OTLP TraceService into
    ``store``. ``stop()`` shuts the server down. Each is a no-op when the
    server is already in the requested state.
    """

    store: LLMCallStore = field(default_factory=LLMCallStore)
    pricing: Any | None = None
    # NOTE(review): "0.0.0.0" listens on all IPv4 interfaces, not just localhost.
    host: str = "0.0.0.0"
    port: int = DEFAULT_GRPC_PORT
    _server: grpc.Server | None = field(default=None, init=False, repr=False)

    def start(self) -> None:
        """Start serving OTLP/gRPC on ``host:port``; do nothing if already running.

        Raises:
            OSError: if the listener cannot be bound to ``host:port``.
        """
        if self._server is not None:
            return
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
        trace_service_pb2_grpc.add_TraceServiceServicer_to_server(
            _ObsServicer(self.store, self.pricing), server
        )
        address = f"{self.host}:{self.port}"
        bind_failure = (
            f"Failed to bind OTLP listener on {address}: port already in use. "
            "Stop tracing via `planoai trace down` or pick another port with --port."
        )
        try:
            bound = server.add_insecure_port(address)
        except RuntimeError as err:
            # Recent grpcio releases raise RuntimeError on bind failure instead
            # of returning 0. Normalize so callers see the same OSError either way.
            raise OSError(bind_failure) from err
        if bound == 0:
            # Older grpcio signals bind failure by returning port 0.
            raise OSError(bind_failure)
        server.start()
        self._server = server

    def stop(self, grace: float = 2.0) -> None:
        """Stop the server, giving in-flight RPCs up to *grace* seconds.

        Does nothing if the server is not running.
        """
        if self._server is not None:
            self._server.stop(grace)
            self._server = None