"""Plan 05-10 — AsyncWriteQueue unit tests (OPS-16, M-03). Coalesce window + batched flush against a LanceDB-shaped async table. Tests use a MockAsyncTable that records each ``add(batch)`` call so we can assert batch sizes / call counts without pulling a live LanceDB connection for every test. Contracts covered: W1 — single enqueue + await resolves within the coalesce window and produces exactly one tbl.add() call containing one record. W2 — two enqueues within land in ONE tbl.add call (batch size 2). This is the core "coalesce" invariant. W3 — max_batch+1 enqueues produce TWO tbl.add calls (max_batch split). W4 — stop() drains an in-flight batch: pending enqueues complete before stop() returns; further enqueue attempts raise RuntimeError. W5 — back-pressure: with max_queue_size=N, the (N+1)th enqueue awaits until a flush drains the queue (never unbounded growth). W6 — on_flushed callback fires once per record, in batch order. W7 — tbl.add() raises -> every pending future in that batch resolves with the same exception; queue stays running so subsequent enqueues still work. """ from __future__ import annotations import asyncio import pytest from iai_mcp.write_queue import AsyncWriteQueue # ------------------------------------------------------------------ mock table class MockAsyncTable: """Minimal stand-in for lancedb AsyncTable. ``add(batch)`` is an awaitable that records every call so tests can assert call_count, batch sizes, and ordering. Supports an injected exception via ``raise_on_add`` to test W7. """ def __init__(self, *, raise_on_add: BaseException | None = None) -> None: self.calls: list[list] = [] self.raise_on_add = raise_on_add # Optional delay to simulate LanceDB flush latency (used by W4). self._delay_s: float = 0.0 async def add(self, batch) -> None: # Copy so later mutations by the queue don't change what we recorded. self.calls.append(list(batch)) if self._delay_s: await asyncio.sleep(self._delay_s) if self.raise_on_add is not None: raise self.raise_on_add # ------------------------------------------------------------------ W1 def test_single_enqueue_flushes_within_coalesce_window(): table = MockAsyncTable() async def run() -> None: q = AsyncWriteQueue(table, coalesce_ms=50, max_batch=128) await q.start() try: fut = await q.enqueue({"id": "r1"}) await asyncio.wait_for(fut, timeout=0.5) finally: await q.stop() asyncio.run(run()) assert len(table.calls) == 1 assert len(table.calls[0]) == 1 assert table.calls[0][0]["id"] == "r1" # ------------------------------------------------------------------ W2 def test_coalesce_window_batches_concurrent_enqueues(): """Two enqueues inside the same coalesce window -> one tbl.add(size=2).""" table = MockAsyncTable() async def run() -> None: q = AsyncWriteQueue(table, coalesce_ms=80, max_batch=128) await q.start() try: fut1 = await q.enqueue({"id": "r1"}) fut2 = await q.enqueue({"id": "r2"}) await asyncio.wait_for(asyncio.gather(fut1, fut2), timeout=0.5) finally: await q.stop() asyncio.run(run()) # Exactly ONE add() call carrying both records, in enqueue order. assert len(table.calls) == 1, f"expected one batched add, got {len(table.calls)}" ids = [r["id"] for r in table.calls[0]] assert ids == ["r1", "r2"] # ------------------------------------------------------------------ W3 def test_max_batch_splits_into_two_flushes(): """max_batch+1 enqueues -> two add() calls (one full, one size=1).""" table = MockAsyncTable() async def run() -> None: q = AsyncWriteQueue(table, coalesce_ms=50, max_batch=4) await q.start() try: futs = [await q.enqueue({"id": f"r{i}"}) for i in range(5)] await asyncio.wait_for(asyncio.gather(*futs), timeout=1.0) finally: await q.stop() asyncio.run(run()) batch_sizes = [len(c) for c in table.calls] # Either [4,1] (strict split) or [5] would violate max_batch. We assert # the total is 5 and at least one batch is <= max_batch=4. assert sum(batch_sizes) == 5 assert len(table.calls) >= 2 assert all(sz <= 4 for sz in batch_sizes) # ------------------------------------------------------------------ W4 def test_stop_drains_pending_records(): """stop() awaits the in-flight batch so enqueued records are durable.""" table = MockAsyncTable() async def run() -> None: q = AsyncWriteQueue(table, coalesce_ms=30, max_batch=128) await q.start() fut = await q.enqueue({"id": "r1"}) # Don't await fut here -- let stop() drain it. await q.stop() assert fut.done(), "stop() must drain pending futures" # Enqueuing after stop should fail. with pytest.raises(RuntimeError): await q.enqueue({"id": "r2"}) asyncio.run(run()) assert sum(len(c) for c in table.calls) == 1 # ------------------------------------------------------------------ W5 def test_backpressure_awaits_when_queue_full(): """max_queue_size=2 -> third enqueue awaits until a flush frees a slot.""" table = MockAsyncTable() # Slow down the flush so back-pressure is observable. table._delay_s = 0.05 async def run() -> int: q = AsyncWriteQueue( table, coalesce_ms=30, max_batch=2, max_queue_size=2, ) await q.start() try: # First two are accepted immediately (fit inside the buffer). f1 = await q.enqueue({"id": "r1"}) f2 = await q.enqueue({"id": "r2"}) # Third one MUST await at least one flush before accepting. t0 = asyncio.get_event_loop().time() f3 = await q.enqueue({"id": "r3"}) waited = asyncio.get_event_loop().time() - t0 await asyncio.gather(f1, f2, f3) return 1 if waited >= 0.01 else 0 finally: await q.stop() waited_flag = asyncio.run(run()) assert waited_flag == 1, "back-pressure enqueue must await at least one flush" # ------------------------------------------------------------------ W6 def test_on_flushed_fires_per_record_in_batch_order(): table = MockAsyncTable() flushed: list[dict] = [] def on_flushed(batch): flushed.extend(batch) async def run() -> None: q = AsyncWriteQueue( table, coalesce_ms=40, max_batch=128, on_flushed=on_flushed, ) await q.start() try: futs = [await q.enqueue({"id": f"r{i}"}) for i in range(3)] await asyncio.gather(*futs) finally: await q.stop() asyncio.run(run()) assert [r["id"] for r in flushed] == ["r0", "r1", "r2"] # ------------------------------------------------------------------ W7 def test_flush_exception_propagates_to_all_futures_in_batch(): err = RuntimeError("lancedb boom") table = MockAsyncTable(raise_on_add=err) async def run() -> tuple[list, MockAsyncTable]: q = AsyncWriteQueue(table, coalesce_ms=30, max_batch=128) await q.start() try: f1 = await q.enqueue({"id": "r1"}) f2 = await q.enqueue({"id": "r2"}) results = [] for f in (f1, f2): try: await f results.append(None) except RuntimeError as exc: results.append(exc) # Queue must stay running; clear the error and try again. table.raise_on_add = None f3 = await q.enqueue({"id": "r3"}) await f3 return results, table finally: await q.stop() results, _ = asyncio.run(run()) assert len(results) == 2 assert all(isinstance(r, RuntimeError) and str(r) == "lancedb boom" for r in results)