197 lines
8.3 KiB
Python
197 lines
8.3 KiB
Python
"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
|
|
from datetime import datetime, timezone
|
|
|
|
import pytest
|
|
|
|
from db import TokenDatabase
|
|
|
|
|
|
@pytest.fixture
async def db(tmp_path):
    """Provide an initialised ``TokenDatabase`` stored in pytest's ``tmp_path``.

    Yields:
        A ``TokenDatabase`` on which ``init_db`` has already been awaited.
        The instance is closed again during fixture teardown.
    """
    # NOTE(review): a plain ``@pytest.fixture`` on an async generator relies on
    # the async test plugin being configured to drive it (e.g. pytest-asyncio
    # in "auto" mode) — confirm against the project's pytest configuration.
    db_file = tmp_path / "tokens.db"
    database = TokenDatabase(str(db_file))
    await database.init_db()

    yield database

    await database.close()
|
|
|
|
|
|
class TestInit:
    """Behaviour of ``TokenDatabase.init_db``."""

    async def test_init_creates_tables(self, db):
        """A repeated ``init_db`` call is harmless and the schema is usable."""
        # The fixture has already awaited init_db once; a second call must
        # not fail (initialisation is expected to be idempotent).
        await db.init_db()

        # Round-tripping a single row proves the token-count table exists.
        await db.update_token_counts("http://ep", "m", 1, 2)
        stored = [row async for row in db.load_token_counts()]
        assert len(stored) == 1

    async def test_creates_parent_directory(self, tmp_path):
        """Missing parent directories exist once ``init_db`` has returned."""
        db_path = tmp_path / "nested" / "subdir" / "x.db"
        database = TokenDatabase(str(db_path))
        await database.init_db()
        try:
            assert db_path.parent.exists()
        finally:
            # Close even if the assertion fails so the instance never leaks.
            await database.close()
|
|
|
|
|
|
class TestUpdateTokenCounts:
    """Accumulation semantics of ``update_token_counts``."""

    async def test_insert_then_update_aggregates(self, db):
        """Two writes to one (endpoint, model) pair are summed into one row."""
        for inp, out in ((10, 20), (5, 7)):
            await db.update_token_counts("http://ep", "m1", inp, out)

        stored = [row async for row in db.load_token_counts()]
        assert len(stored) == 1

        only = stored[0]
        expected = {
            "endpoint": "http://ep",
            "model": "m1",
            "input_tokens": 15,
            "output_tokens": 27,
            "total_tokens": 42,
        }
        for column, value in expected.items():
            assert only[column] == value

    async def test_independent_endpoint_model_pairs(self, db):
        """Distinct (endpoint, model) pairs are tracked in separate rows."""
        writes = (
            ("http://ep1", "m1", 1),
            ("http://ep1", "m2", 2),
            ("http://ep2", "m1", 3),
        )
        for endpoint, model, amount in writes:
            await db.update_token_counts(endpoint, model, amount, amount)

        stored = [row async for row in db.load_token_counts()]
        assert len(stored) == 3

        by_pair = {
            (row["endpoint"], row["model"]): row["total_tokens"]
            for row in stored
        }
        # Each pair received `amount` input and `amount` output tokens.
        assert by_pair == {
            (endpoint, model): 2 * amount for endpoint, model, amount in writes
        }
|
|
|
|
|
|
class TestBatchedCounts:
    """Bulk writes through ``update_batched_counts``."""

    async def test_update_batched_counts(self, db):
        """Every (endpoint, model) entry of the nested mapping becomes a row."""
        counts = {
            "http://a": {"m": (4, 6)},
            "http://b": {"m": (1, 1), "n": (10, 0)},
        }
        await db.update_batched_counts(counts)

        rows = [r async for r in db.load_token_counts()]
        # Check the row count first: the dict built below would silently
        # collapse duplicate (endpoint, model) rows and hide such a bug.
        assert len(rows) == 3

        totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
        assert totals == {
            ("http://a", "m"): 10,
            ("http://b", "m"): 2,
            ("http://b", "n"): 10,
        }

    async def test_empty_batch_is_noop(self, db):
        """An empty mapping writes nothing."""
        await db.update_batched_counts({})
        rows = [r async for r in db.load_token_counts()]
        assert rows == []
|
|
|
|
|
|
class TestTimeSeries:
    """Raw time-series inserts and the read helpers built on them."""

    async def test_add_time_series_entry(self, db):
        """Single-entry inserts are returned by ``get_latest_time_series``."""
        # The aggregate FK requires the (endpoint,model) row to exist first
        await db.update_token_counts("http://ep", "m", 0, 0)
        await db.add_time_series_entry("http://ep", "m", 3, 4)
        await db.add_time_series_entry("http://ep", "m", 1, 1)

        rows = [r async for r in db.get_latest_time_series(limit=10)]
        assert len(rows) == 2

        # The entries carry no explicit timestamp, so both are presumably
        # stamped with the current time and may tie; compare the written
        # totals without relying on their relative order.
        assert sorted(r["total_tokens"] for r in rows) == [2, 7]
        for r in rows:
            assert r["endpoint"] == "http://ep"
            assert r["model"] == "m"
            assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"]

    async def test_add_batched_time_series(self, db):
        """Batched inserts are returned newest-first."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        entries = [
            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
             "output_tokens": 2, "total_tokens": 3, "timestamp": now - 60},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 4,
             "output_tokens": 5, "total_tokens": 9, "timestamp": now},
        ]
        await db.add_batched_time_series(entries)

        rows = [r async for r in db.get_latest_time_series(limit=10)]
        assert len(rows) == 2
        assert rows[0]["timestamp"] >= rows[1]["timestamp"]
        # The timestamps differ by 60 s, so the newest row is unambiguous.
        assert rows[0]["total_tokens"] == 9

    async def test_get_time_series_for_model_filters(self, db):
        """Only rows for the requested model are returned."""
        await db.update_token_counts("http://ep", "m1", 0, 0)
        await db.update_token_counts("http://ep", "m2", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        await db.add_batched_time_series([
            {"endpoint": "http://ep", "model": "m1", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": now},
            {"endpoint": "http://ep", "model": "m2", "input_tokens": 9,
             "output_tokens": 9, "total_tokens": 18, "timestamp": now},
        ])

        rows = [r async for r in db.get_time_series_for_model("m1")]
        assert len(rows) == 1
        assert rows[0]["total_tokens"] == 2

    async def test_endpoint_distribution_for_model(self, db):
        """Total tokens for one model are summed per endpoint."""
        await db.update_token_counts("http://a", "m", 0, 0)
        await db.update_token_counts("http://b", "m", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        await db.add_batched_time_series([
            {"endpoint": "http://a", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": now},
            {"endpoint": "http://a", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": now},
            {"endpoint": "http://b", "model": "m", "input_tokens": 5,
             "output_tokens": 5, "total_tokens": 10, "timestamp": now},
        ])

        dist = await db.get_endpoint_distribution_for_model("m")
        assert dist == {"http://a": 4, "http://b": 10}
|
|
|
|
|
|
class TestGetTokenCountsForModel:
    """Per-model aggregation across endpoints."""

    async def test_aggregates_across_endpoints(self, db):
        """Counts for one model are summed over all of its endpoints."""
        await db.update_token_counts("http://a", "m", 1, 2)
        await db.update_token_counts("http://b", "m", 3, 4)

        result = await db.get_token_counts_for_model("m")
        assert result is not None
        assert result["endpoint"] == "aggregated"
        assert result["model"] == "m"
        assert result["input_tokens"] == 4
        assert result["output_tokens"] == 6
        assert result["total_tokens"] == 10

    async def test_unknown_model_returns_zero_aggregate(self, db):
        """An unknown model still yields one aggregate row with empty sums."""
        # SUM(...) over zero matching rows returns a single row of NULLs.
        # The implementation may surface those as None or coalesce them to 0,
        # so either is accepted — but consistently for every token column.
        result = await db.get_token_counts_for_model("nope")
        assert result is not None
        for column in ("input_tokens", "output_tokens", "total_tokens"):
            assert result[column] in (0, None)
|
|
|
|
|
|
class TestAggregateTimeSeriesOlderThan:
    """Day-level roll-up (and optional trimming) of old time-series rows."""

    @staticmethod
    def _timestamps(days_ago=40):
        """Return ``(now, old)`` as integer epoch seconds.

        ``old`` is 12:00 UTC on the calendar day ``days_ago`` days before the
        current UTC day, so it is roughly ``days_ago`` days in the past.
        Anchoring to midday (rather than ``now - days_ago * 86400``) keeps
        ``old`` and a timestamp a few minutes later on the same calendar day.
        Tests expecting them in one daily bucket therefore cannot straddle
        midnight. This holds in UTC and, in case the implementation buckets
        by local time, for real-world UTC offsets (-12h..+14h).
        """
        now = int(datetime.now(tz=timezone.utc).timestamp())
        start_of_today = now - now % 86400  # 00:00 UTC of the current day
        old = start_of_today - days_ago * 86400 + 12 * 3600
        return now, old

    async def test_aggregates_old_entries_by_day(self, db):
        """Two old rows on the same day roll up into a single group."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        now, old = self._timestamps()  # old: ~40 days ago, noon UTC
        await db.add_batched_time_series([
            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": old},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 3,
             "output_tokens": 3, "total_tokens": 6, "timestamp": old + 60},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
             "output_tokens": 99, "total_tokens": 198, "timestamp": now},
        ])

        n = await db.aggregate_time_series_older_than(30, trim_old=False)
        assert n == 1  # one (endpoint, model, day) group rolled up

    async def test_invalid_days_falls_back_to_30(self, db):
        """``days=0`` is treated as invalid and replaced by 30 days."""
        # On an empty database a bogus value must not blow up.
        assert await db.aggregate_time_series_older_than(0) == 0

        # Distinguish "fallback to 30" from a literal 0-day cutoff: a row
        # 5 days old lies inside a 30-day window but is older than 0 days.
        await db.update_token_counts("http://ep", "m", 0, 0)
        now, old = self._timestamps()
        await db.add_batched_time_series([
            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": old},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2,
             "timestamp": now - 5 * 86400},
        ])

        n = await db.aggregate_time_series_older_than(0, trim_old=False)
        assert n == 1  # only the ~40-day-old row's group is rolled up

    async def test_trim_old_removes_aggregated_rows(self, db):
        """With ``trim_old=True`` the aggregated raw rows are removed."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        now, old = self._timestamps()
        await db.add_batched_time_series([
            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": old},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
             "output_tokens": 99, "total_tokens": 198, "timestamp": now},
        ])

        await db.aggregate_time_series_older_than(30, trim_old=True)

        remaining = [r async for r in db.get_latest_time_series(limit=10)]
        # Only the recent (within-cutoff) row should remain
        assert len(remaining) == 1
        assert remaining[0]["total_tokens"] == 198
|