"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
from datetime import datetime, timezone

import pytest

from db import TokenDatabase

_SECONDS_PER_DAY = 86400


def _now_ts() -> int:
    """Return the current UTC time as integer epoch seconds."""
    return int(datetime.now(tz=timezone.utc).timestamp())


def _midday_ts(days_ago: int) -> int:
    """Return epoch seconds for 12:00 UTC on the UTC day ``days_ago`` days back.

    Anchoring to midday keeps small offsets (e.g. ``+ 60`` seconds) inside one
    calendar day, so per-day roll-up tests cannot straddle midnight and flake.
    The result is at most half a day later than ``now - days_ago`` days.
    """
    shifted = _now_ts() - days_ago * _SECONDS_PER_DAY
    day_start = shifted // _SECONDS_PER_DAY * _SECONDS_PER_DAY
    return day_start + _SECONDS_PER_DAY // 2


def _entry(endpoint, model, input_tokens, output_tokens, timestamp):
    """Build one time-series row dict; ``total_tokens`` is input + output."""
    return {
        "endpoint": endpoint,
        "model": model,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "timestamp": timestamp,
    }


@pytest.fixture
async def db(tmp_path):
    """Yield an initialised TokenDatabase backed by a temp file, closed afterwards."""
    inst = TokenDatabase(str(tmp_path / "tokens.db"))
    await inst.init_db()
    yield inst
    await inst.close()


class TestInit:
    """Schema creation and on-disk setup."""

    async def test_init_creates_tables(self, db):
        # Re-init must be idempotent
        await db.init_db()
        # Insert + read confirms tables exist
        await db.update_token_counts("http://ep", "m", 1, 2)
        rows = [r async for r in db.load_token_counts()]
        assert len(rows) == 1

    async def test_creates_parent_directory(self, tmp_path):
        nested = tmp_path / "nested" / "subdir" / "x.db"
        inst = TokenDatabase(str(nested))
        await inst.init_db()
        try:
            assert nested.parent.exists()
        finally:
            await inst.close()


class TestUpdateTokenCounts:
    """Per-(endpoint, model) aggregate counters."""

    async def test_insert_then_update_aggregates(self, db):
        await db.update_token_counts("http://ep", "m1", 10, 20)
        await db.update_token_counts("http://ep", "m1", 5, 7)
        rows = [r async for r in db.load_token_counts()]
        assert len(rows) == 1
        r = rows[0]
        assert r["endpoint"] == "http://ep"
        assert r["model"] == "m1"
        assert r["input_tokens"] == 15
        assert r["output_tokens"] == 27
        assert r["total_tokens"] == 42

    async def test_independent_endpoint_model_pairs(self, db):
        await db.update_token_counts("http://ep1", "m1", 1, 1)
        await db.update_token_counts("http://ep1", "m2", 2, 2)
        await db.update_token_counts("http://ep2", "m1", 3, 3)
        rows = [r async for r in db.load_token_counts()]
        assert len(rows) == 3
        totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
        assert totals == {
            ("http://ep1", "m1"): 2,
            ("http://ep1", "m2"): 4,
            ("http://ep2", "m1"): 6,
        }


class TestBatchedCounts:
    """Bulk counter updates from a nested {endpoint: {model: (in, out)}} mapping."""

    async def test_update_batched_counts(self, db):
        counts = {
            "http://a": {"m": (4, 6)},
            "http://b": {"m": (1, 1), "n": (10, 0)},
        }
        await db.update_batched_counts(counts)
        rows = [r async for r in db.load_token_counts()]
        totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
        assert totals == {
            ("http://a", "m"): 10,
            ("http://b", "m"): 2,
            ("http://b", "n"): 10,
        }

    async def test_empty_batch_is_noop(self, db):
        await db.update_batched_counts({})
        rows = [r async for r in db.load_token_counts()]
        assert rows == []


class TestTimeSeries:
    """Time-series inserts and the read queries built on them."""

    async def test_add_time_series_entry(self, db):
        # The aggregate FK requires the (endpoint,model) row to exist first
        await db.update_token_counts("http://ep", "m", 0, 0)
        await db.add_time_series_entry("http://ep", "m", 3, 4)
        await db.add_time_series_entry("http://ep", "m", 1, 1)
        rows = [r async for r in db.get_latest_time_series(limit=10)]
        assert len(rows) == 2
        # The two entries are written back-to-back and may share a timestamp,
        # so their relative order is not asserted — only that rows are well-formed.
        for r in rows:
            assert r["endpoint"] == "http://ep"
            assert r["model"] == "m"
            assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"]

    async def test_add_batched_time_series(self, db):
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = _now_ts()
        await db.add_batched_time_series([
            _entry("http://ep", "m", 1, 2, now - 60),
            _entry("http://ep", "m", 4, 5, now),
        ])
        rows = [r async for r in db.get_latest_time_series(limit=10)]
        assert len(rows) == 2
        # Distinct timestamps here, so newest-first ordering is checkable.
        assert rows[0]["timestamp"] >= rows[1]["timestamp"]

    async def test_get_time_series_for_model_filters(self, db):
        await db.update_token_counts("http://ep", "m1", 0, 0)
        await db.update_token_counts("http://ep", "m2", 0, 0)
        now = _now_ts()
        await db.add_batched_time_series([
            _entry("http://ep", "m1", 1, 1, now),
            _entry("http://ep", "m2", 9, 9, now),
        ])
        rows = [r async for r in db.get_time_series_for_model("m1")]
        assert len(rows) == 1
        assert rows[0]["total_tokens"] == 2

    async def test_endpoint_distribution_for_model(self, db):
        await db.update_token_counts("http://a", "m", 0, 0)
        await db.update_token_counts("http://b", "m", 0, 0)
        now = _now_ts()
        await db.add_batched_time_series([
            _entry("http://a", "m", 1, 1, now),
            _entry("http://a", "m", 1, 1, now),
            _entry("http://b", "m", 5, 5, now),
        ])
        dist = await db.get_endpoint_distribution_for_model("m")
        assert dist == {"http://a": 4, "http://b": 10}


class TestGetTokenCountsForModel:
    """Cross-endpoint aggregation for a single model."""

    async def test_aggregates_across_endpoints(self, db):
        await db.update_token_counts("http://a", "m", 1, 2)
        await db.update_token_counts("http://b", "m", 3, 4)
        result = await db.get_token_counts_for_model("m")
        assert result is not None
        assert result["endpoint"] == "aggregated"
        assert result["model"] == "m"
        assert result["input_tokens"] == 4
        assert result["output_tokens"] == 6
        assert result["total_tokens"] == 10

    async def test_unknown_model_returns_zero_aggregate(self, db):
        # SUM(...) over zero matching rows yields a single row of NULLs; the
        # implementation may surface those either as None or coerced to 0, so
        # both are accepted here.
        result = await db.get_token_counts_for_model("nope")
        assert result is not None
        for key in ("input_tokens", "output_tokens", "total_tokens"):
            assert result[key] in (0, None)


class TestAggregateTimeSeriesOlderThan:
    """Daily roll-up (and optional trimming) of old time-series rows."""

    async def test_aggregates_old_entries_by_day(self, db):
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = _now_ts()
        # ~40 days ago, anchored to midday so old and old + 60 share a day
        # (an unanchored timestamp could straddle midnight and yield 2 groups).
        old = _midday_ts(40)
        await db.add_batched_time_series([
            _entry("http://ep", "m", 1, 1, old),
            _entry("http://ep", "m", 3, 3, old + 60),
            _entry("http://ep", "m", 99, 99, now),
        ])
        n = await db.aggregate_time_series_older_than(30, trim_old=False)
        assert n == 1  # one (endpoint, model, day) group rolled up

    async def test_invalid_days_falls_back_to_30(self, db):
        # Just ensure it doesn't blow up with a bogus value.
        # NOTE(review): on an empty DB this cannot distinguish "fell back to
        # 30" from any other cutoff; seeding rows at e.g. 10 and 40 days old
        # would let the test verify the fallback its name describes.
        n = await db.aggregate_time_series_older_than(0)
        assert n == 0

    async def test_trim_old_removes_aggregated_rows(self, db):
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = _now_ts()
        old = _midday_ts(40)
        await db.add_batched_time_series([
            _entry("http://ep", "m", 1, 1, old),
            _entry("http://ep", "m", 99, 99, now),
        ])
        await db.aggregate_time_series_older_than(30, trim_old=True)
        remaining = [r async for r in db.get_latest_time_series(limit=10)]
        # Only the recent (within-cutoff) row should remain
        assert len(remaining) == 1
        assert remaining[0]["total_tokens"] == 198