nomyo-router/test/test_db.py

197 lines
8.3 KiB
Python

"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
from datetime import datetime, timezone
import pytest
from db import TokenDatabase
@pytest.fixture
async def db(tmp_path):
    """Yield an initialised ``TokenDatabase`` backed by a file in ``tmp_path``.

    The database lives at ``tmp_path / "tokens.db"``. ``init_db`` is awaited
    before the instance is handed to the test, and ``close`` is awaited once
    the test has finished with it.
    """
    db_file = tmp_path / "tokens.db"
    database = TokenDatabase(str(db_file))
    await database.init_db()
    yield database
    await database.close()
class TestInit:
    """Initialisation behaviour of ``TokenDatabase.init_db``."""

    async def test_init_creates_tables(self, db):
        """A repeated ``init_db`` succeeds and the tables stay usable."""
        # The fixture already initialised the database once; re-init must be
        # idempotent.
        await db.init_db()

        # A write followed by a read confirms the tables exist.
        await db.update_token_counts("http://ep", "m", 1, 2)
        stored = []
        async for row in db.load_token_counts():
            stored.append(row)
        assert len(stored) == 1

    async def test_creates_parent_directory(self, tmp_path):
        """The DB file's missing parent directories exist after ``init_db``."""
        db_path = tmp_path / "nested" / "subdir" / "x.db"
        database = TokenDatabase(str(db_path))
        await database.init_db()
        try:
            assert db_path.parent.exists()
        finally:
            # Always release the connection, even when the assertion fails.
            await database.close()
class TestUpdateTokenCounts:
    """Behaviour of ``TokenDatabase.update_token_counts``."""

    async def test_insert_then_update_aggregates(self, db):
        """Two updates for the same pair accumulate into a single row."""
        for tokens_in, tokens_out in ((10, 20), (5, 7)):
            await db.update_token_counts("http://ep", "m1", tokens_in, tokens_out)

        stored = [row async for row in db.load_token_counts()]
        assert len(stored) == 1

        expected = {
            "endpoint": "http://ep",
            "model": "m1",
            "input_tokens": 15,
            "output_tokens": 27,
            "total_tokens": 42,
        }
        only_row = stored[0]
        # Compare just the fields under test; the row may carry others.
        observed = {field: only_row[field] for field in expected}
        assert observed == expected

    async def test_independent_endpoint_model_pairs(self, db):
        """Each distinct (endpoint, model) pair gets its own counter row."""
        updates = (
            ("http://ep1", "m1", 1, 1),
            ("http://ep1", "m2", 2, 2),
            ("http://ep2", "m1", 3, 3),
        )
        for endpoint, model, tokens_in, tokens_out in updates:
            await db.update_token_counts(endpoint, model, tokens_in, tokens_out)

        stored = [row async for row in db.load_token_counts()]
        assert len(stored) == 3

        totals_by_pair = {}
        for row in stored:
            totals_by_pair[(row["endpoint"], row["model"])] = row["total_tokens"]
        assert totals_by_pair == {
            ("http://ep1", "m1"): 2,
            ("http://ep1", "m2"): 4,
            ("http://ep2", "m1"): 6,
        }
class TestBatchedCounts:
    """Behaviour of ``TokenDatabase.update_batched_counts``."""

    async def test_update_batched_counts(self, db):
        """A nested endpoint -> model -> (input, output) batch is stored per pair."""
        batch = {
            "http://a": {"m": (4, 6)},
            "http://b": {"m": (1, 1), "n": (10, 0)},
        }
        await db.update_batched_counts(batch)

        totals_by_pair = {}
        async for row in db.load_token_counts():
            totals_by_pair[(row["endpoint"], row["model"])] = row["total_tokens"]

        assert totals_by_pair == {
            ("http://a", "m"): 10,
            ("http://b", "m"): 2,
            ("http://b", "n"): 10,
        }

    async def test_empty_batch_is_noop(self, db):
        """An empty batch leaves the counts table empty."""
        await db.update_batched_counts({})
        stored = [row async for row in db.load_token_counts()]
        assert stored == []
class TestTimeSeries:
    """Writing and querying time-series rows."""

    @staticmethod
    def _entry(endpoint, model, input_tokens, output_tokens, timestamp):
        """Build one batched time-series entry; the total is input + output."""
        return {
            "endpoint": endpoint,
            "model": model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "timestamp": timestamp,
        }

    @staticmethod
    def _utc_now():
        """Return the current UTC time as whole epoch seconds."""
        return int(datetime.now(tz=timezone.utc).timestamp())

    async def test_add_time_series_entry(self, db):
        """Two single-entry inserts are both returned with consistent totals."""
        endpoint, model = "http://ep", "m"
        # The aggregate FK requires the (endpoint,model) row to exist first
        await db.update_token_counts(endpoint, model, 0, 0)
        await db.add_time_series_entry(endpoint, model, 3, 4)
        await db.add_time_series_entry(endpoint, model, 1, 1)

        entries = [e async for e in db.get_latest_time_series(limit=10)]
        assert len(entries) == 2
        # Ordering is newest-first, but both timestamps fall within the same
        # minute, so only check that each row is well-formed.
        for entry in entries:
            assert entry["endpoint"] == endpoint
            assert entry["model"] == model
            assert entry["total_tokens"] == (
                entry["input_tokens"] + entry["output_tokens"]
            )

    async def test_add_batched_time_series(self, db):
        """Batched entries are all stored and come back newest-first."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        current = self._utc_now()
        batch = [
            self._entry("http://ep", "m", 1, 2, current - 60),
            self._entry("http://ep", "m", 4, 5, current),
        ]
        await db.add_batched_time_series(batch)

        entries = [e async for e in db.get_latest_time_series(limit=10)]
        assert len(entries) == 2
        newest, older = entries[0], entries[1]
        assert newest["timestamp"] >= older["timestamp"]

    async def test_get_time_series_for_model_filters(self, db):
        """Only rows of the requested model are returned."""
        for model in ("m1", "m2"):
            await db.update_token_counts("http://ep", model, 0, 0)
        current = self._utc_now()
        await db.add_batched_time_series([
            self._entry("http://ep", "m1", 1, 1, current),
            self._entry("http://ep", "m2", 9, 9, current),
        ])

        matches = [e async for e in db.get_time_series_for_model("m1")]
        assert len(matches) == 1
        assert matches[0]["total_tokens"] == 2

    async def test_endpoint_distribution_for_model(self, db):
        """Total tokens are summed per endpoint for the given model."""
        for endpoint in ("http://a", "http://b"):
            await db.update_token_counts(endpoint, "m", 0, 0)
        current = self._utc_now()
        await db.add_batched_time_series([
            self._entry("http://a", "m", 1, 1, current),
            self._entry("http://a", "m", 1, 1, current),
            self._entry("http://b", "m", 5, 5, current),
        ])

        distribution = await db.get_endpoint_distribution_for_model("m")
        assert distribution == {"http://a": 4, "http://b": 10}
class TestGetTokenCountsForModel:
    """Behaviour of ``TokenDatabase.get_token_counts_for_model``."""

    async def test_aggregates_across_endpoints(self, db):
        """Counts for one model are summed over every endpoint serving it."""
        for endpoint, tokens_in, tokens_out in (
            ("http://a", 1, 2),
            ("http://b", 3, 4),
        ):
            await db.update_token_counts(endpoint, "m", tokens_in, tokens_out)

        summary = await db.get_token_counts_for_model("m")
        assert summary is not None

        expected = {
            "endpoint": "aggregated",
            "model": "m",
            "input_tokens": 4,
            "output_tokens": 6,
            "total_tokens": 10,
        }
        # Compare just the fields under test; the result may carry others.
        observed = {field: summary[field] for field in expected}
        assert observed == expected

    async def test_unknown_model_returns_zero_aggregate(self, db):
        """An unknown model still yields a result rather than ``None``."""
        summary = await db.get_token_counts_for_model("nope")
        assert summary is not None
        # SUM(...) WHERE no-match returns one row with NULLs, so accept either
        # a zero or a None input count.
        assert summary["input_tokens"] in (0, None)
class TestAggregateTimeSeriesOlderThan:
    """Behaviour of ``TokenDatabase.aggregate_time_series_older_than``."""

    async def test_aggregates_old_entries_by_day(self, db):
        """Two old rows from the same day roll up into a single group."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        seconds_per_day = 86400
        now = int(datetime.now(tz=timezone.utc).timestamp())
        # Roughly 40 days ago, snapped down to UTC midnight. Without the
        # snap, `old` and `old + 60` straddle a day boundary whenever the test
        # runs in the last minute of a day, producing two groups and a
        # spurious failure. The snapped value is still more than 40 days old,
        # so it stays well beyond the 30-day cutoff used below.
        old = (now - 40 * seconds_per_day) // seconds_per_day * seconds_per_day
        await db.add_batched_time_series([
            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": old},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 3,
             "output_tokens": 3, "total_tokens": 6, "timestamp": old + 60},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
             "output_tokens": 99, "total_tokens": 198, "timestamp": now},
        ])
        n = await db.aggregate_time_series_older_than(30, trim_old=False)
        assert n == 1  # one (endpoint, model, day) group rolled up

    async def test_invalid_days_falls_back_to_30(self, db):
        """A bogus ``days`` value on an empty database returns 0 groups."""
        # Just ensure it doesn't blow up with a bogus value
        n = await db.aggregate_time_series_older_than(0)
        assert n == 0

    async def test_trim_old_removes_aggregated_rows(self, db):
        """With ``trim_old=True`` only the recent row is left in the series."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        old = now - (40 * 86400)  # 40 days ago, beyond the 30-day cutoff
        await db.add_batched_time_series([
            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
             "output_tokens": 1, "total_tokens": 2, "timestamp": old},
            {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
             "output_tokens": 99, "total_tokens": 198, "timestamp": now},
        ])
        await db.aggregate_time_series_older_than(30, trim_old=True)
        remaining = [r async for r in db.get_latest_time_series(limit=10)]
        # Only the recent (within-cutoff) row should remain
        assert len(remaining) == 1
        assert remaining[0]["total_tokens"] == 198