black formatting

This commit is contained in:
Adil Hafeez 2026-04-17 01:02:12 -07:00
parent d30018cf35
commit d09fa97568
5 changed files with 52 additions and 13 deletions

View file

@ -15,7 +15,6 @@ from opentelemetry.proto.collector.trace.v1 import (
trace_service_pb2_grpc, trace_service_pb2_grpc,
) )
DEFAULT_GRPC_PORT = 4317 DEFAULT_GRPC_PORT = 4317
DEFAULT_CAPACITY = 1000 DEFAULT_CAPACITY = 1000
@ -198,9 +197,9 @@ def span_to_llm_call(
route_name=( route_name=(
str(attrs[_PLANO_ROUTE_NAME]) if _PLANO_ROUTE_NAME in attrs else None str(attrs[_PLANO_ROUTE_NAME]) if _PLANO_ROUTE_NAME in attrs else None
), ),
is_streaming=bool(attrs[_LLM_IS_STREAMING]) is_streaming=(
if _LLM_IS_STREAMING in attrs bool(attrs[_LLM_IS_STREAMING]) if _LLM_IS_STREAMING in attrs else None
else None, ),
status_code=_maybe_int(attrs.get(_HTTP_STATUS)), status_code=_maybe_int(attrs.get(_HTTP_STATUS)),
prompt_tokens=_maybe_int(attrs.get(_LLM_PROMPT_TOKENS)), prompt_tokens=_maybe_int(attrs.get(_LLM_PROMPT_TOKENS)),
completion_tokens=_maybe_int(attrs.get(_LLM_COMPLETION_TOKENS)), completion_tokens=_maybe_int(attrs.get(_LLM_COMPLETION_TOKENS)),

View file

@ -13,7 +13,6 @@ from typing import Any
import requests import requests
DEFAULT_PRICING_URL = "https://api.digitalocean.com/v2/gen-ai/models/catalog" DEFAULT_PRICING_URL = "https://api.digitalocean.com/v2/gen-ai/models/catalog"
FETCH_TIMEOUT_SECS = 5.0 FETCH_TIMEOUT_SECS = 5.0

View file

@ -271,7 +271,11 @@ def _recent_table(calls: list[LLMCall], limit: int = 15) -> Table:
recent = list(reversed(calls))[:limit] recent = list(reversed(calls))[:limit]
for c in recent: for c in recent:
status_cell = "ok" if c.status_code and 200 <= c.status_code < 400 else str(c.status_code or "") status_cell = (
"ok"
if c.status_code and 200 <= c.status_code < 400
else str(c.status_code or "")
)
row = [ row = [
c.timestamp.strftime("%H:%M:%S"), c.timestamp.strftime("%H:%M:%S"),
c.model, c.model,

View file

@ -28,7 +28,9 @@ def _mk_attr(key: str, value):
return kv return kv
def _mk_span(attrs: dict, start_ns: int | None = None, span_id_hex: str = "ab") -> MagicMock: def _mk_span(
attrs: dict, start_ns: int | None = None, span_id_hex: str = "ab"
) -> MagicMock:
span = MagicMock() span = MagicMock()
span.attributes = [_mk_attr(k, v) for k, v in attrs.items()] span.attributes = [_mk_attr(k, v) for k, v in attrs.items()]
span.start_time_unix_nano = start_ns or int(time.time() * 1_000_000_000) span.start_time_unix_nano = start_ns or int(time.time() * 1_000_000_000)
@ -84,7 +86,9 @@ def test_pricing_lookup_attaches_cost():
class StubPricing: class StubPricing:
def cost_for_call(self, call): def cost_for_call(self, call):
# Simple: 2 * prompt + 3 * completion, in cents # Simple: 2 * prompt + 3 * completion, in cents
return 0.02 * (call.prompt_tokens or 0) + 0.03 * (call.completion_tokens or 0) return 0.02 * (call.prompt_tokens or 0) + 0.03 * (
call.completion_tokens or 0
)
span = _mk_span( span = _mk_span(
{ {

View file

@ -4,7 +4,17 @@ from planoai.obs.collector import LLMCall
from planoai.obs.render import aggregates, model_rollups, route_hits from planoai.obs.render import aggregates, model_rollups, route_hits
def _call(model: str, ts: datetime, prompt=0, completion=0, cost=None, route=None, session=None, cache_read=0, cache_write=0): def _call(
model: str,
ts: datetime,
prompt=0,
completion=0,
cost=None,
route=None,
session=None,
cache_read=0,
cache_write=0,
):
return LLMCall( return LLMCall(
request_id="r", request_id="r",
timestamp=ts, timestamp=ts,
@ -22,9 +32,30 @@ def _call(model: str, ts: datetime, prompt=0, completion=0, cost=None, route=Non
def test_aggregates_sum_and_session_counts(): def test_aggregates_sum_and_session_counts():
now = datetime.now(tz=timezone.utc).astimezone() now = datetime.now(tz=timezone.utc).astimezone()
calls = [ calls = [
_call("m1", now - timedelta(seconds=50), prompt=10, completion=5, cost=0.001, session="s1"), _call(
_call("m2", now - timedelta(seconds=40), prompt=20, completion=10, cost=0.002, session="s1"), "m1",
_call("m1", now - timedelta(seconds=30), prompt=30, completion=15, cost=0.003, session="s2"), now - timedelta(seconds=50),
prompt=10,
completion=5,
cost=0.001,
session="s1",
),
_call(
"m2",
now - timedelta(seconds=40),
prompt=20,
completion=10,
cost=0.002,
session="s1",
),
_call(
"m1",
now - timedelta(seconds=30),
prompt=30,
completion=15,
cost=0.003,
session="s2",
),
] ]
stats = aggregates(calls) stats = aggregates(calls)
assert stats.count == 3 assert stats.count == 3
@ -38,7 +69,9 @@ def test_aggregates_sum_and_session_counts():
def test_rollups_split_by_model_and_cache(): def test_rollups_split_by_model_and_cache():
now = datetime.now(tz=timezone.utc).astimezone() now = datetime.now(tz=timezone.utc).astimezone()
calls = [ calls = [
_call("m1", now, prompt=10, completion=5, cost=0.001, cache_write=3, cache_read=7), _call(
"m1", now, prompt=10, completion=5, cost=0.001, cache_write=3, cache_read=7
),
_call("m1", now, prompt=20, completion=10, cost=0.002, cache_read=1), _call("m1", now, prompt=20, completion=10, cost=0.002, cache_read=1),
_call("m2", now, prompt=30, completion=15, cost=0.004), _call("m2", now, prompt=30, completion=15, cost=0.004),
] ]