mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-30 10:56:23 +02:00
Updated test suite for explainability & provenance (#696)
* Provenance tests * Embeddings tests * Test librarian * Test triples stream * Test concurrency * Entity centric graph writes * Agent tool service tests * Structured data tests * RDF tests * Addition LLM tests * Reliability tests
This commit is contained in:
parent
e6623fc915
commit
29b4300808
36 changed files with 8799 additions and 0 deletions
164
tests/unit/test_embeddings/test_document_embeddings_processor.py
Normal file
164
tests/unit/test_embeddings/test_document_embeddings_processor.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""
|
||||
Tests for document embeddings processor — single-chunk embedding via batch API.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.embeddings.document_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import (
|
||||
Chunk, DocumentEmbeddings, ChunkEmbeddings,
|
||||
EmbeddingsRequest, EmbeddingsResponse, Metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
return Processor(
|
||||
taskgroup=AsyncMock(),
|
||||
id="test-doc-embeddings",
|
||||
)
|
||||
|
||||
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
|
||||
user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
return msg
|
||||
|
||||
|
||||
class TestDocumentEmbeddingsProcessor:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_single_text_as_list(self, processor):
|
||||
"""Document embeddings should wrap single chunk in a list for the API."""
|
||||
msg = _make_chunk_message("test chunk text")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.1, 0.2, 0.3]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# Should send EmbeddingsRequest with texts=[chunk]
|
||||
mock_request.assert_called_once()
|
||||
req = mock_request.call_args[0][0]
|
||||
assert isinstance(req, EmbeddingsRequest)
|
||||
assert req.texts == ["test chunk text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_first_vector(self, processor):
|
||||
"""Should use vectors[0] from the response."""
|
||||
msg = _make_chunk_message("chunk")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[1.0, 2.0, 3.0]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert isinstance(result, DocumentEmbeddings)
|
||||
assert len(result.chunks) == 1
|
||||
assert result.chunks[0].vector == [1.0, 2.0, 3.0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_vectors_response(self, processor):
|
||||
"""Should handle empty vectors response gracefully."""
|
||||
msg = _make_chunk_message("chunk")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.chunks[0].vector == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_is_document_id(self, processor):
|
||||
"""ChunkEmbeddings should use document_id as chunk_id."""
|
||||
msg = _make_chunk_message(doc_id="my-doc-42")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.chunks[0].chunk_id == "my-doc-42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_preserved(self, processor):
|
||||
"""Output should carry the original metadata."""
|
||||
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "reports"
|
||||
assert result.metadata.id == "d1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_propagates(self, processor):
|
||||
"""Embedding errors should propagate for retry."""
|
||||
msg = _make_chunk_message()
|
||||
|
||||
mock_request = AsyncMock(side_effect=RuntimeError("service down"))
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
return MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="service down"):
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
109
tests/unit/test_embeddings/test_embeddings_client.py
Normal file
109
tests/unit/test_embeddings/test_embeddings_client.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
Tests for EmbeddingsClient — the client interface for batch embeddings.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.base.embeddings_client import EmbeddingsClient
|
||||
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
|
||||
|
||||
|
||||
class TestEmbeddingsClient:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_sends_request_and_returns_vectors(self):
|
||||
"""embed() should send an EmbeddingsRequest and return vectors."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None,
|
||||
vectors=[[0.1, 0.2], [0.3, 0.4]],
|
||||
))
|
||||
|
||||
result = await client.embed(texts=["hello", "world"])
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
client.request.assert_called_once()
|
||||
req = client.request.call_args[0][0]
|
||||
assert isinstance(req, EmbeddingsRequest)
|
||||
assert req.texts == ["hello", "world"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_single_text(self):
|
||||
"""embed() should work with a single text."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None,
|
||||
vectors=[[1.0, 2.0, 3.0]],
|
||||
))
|
||||
|
||||
result = await client.embed(texts=["single"])
|
||||
|
||||
assert result == [[1.0, 2.0, 3.0]]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_raises_on_error_response(self):
|
||||
"""embed() should raise RuntimeError when response contains an error."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=Error(type="embeddings-error", message="model not found"),
|
||||
vectors=[],
|
||||
))
|
||||
|
||||
with pytest.raises(RuntimeError, match="model not found"):
|
||||
await client.embed(texts=["test"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_passes_timeout(self):
|
||||
"""embed() should pass timeout to the underlying request."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]],
|
||||
))
|
||||
|
||||
await client.embed(texts=["test"], timeout=60)
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 60
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_default_timeout(self):
|
||||
"""embed() should use 300s default timeout."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]],
|
||||
))
|
||||
|
||||
await client.embed(texts=["test"])
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 300
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_empty_texts(self):
|
||||
"""embed() with empty list should still make the request."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[],
|
||||
))
|
||||
|
||||
result = await client.embed(texts=[])
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_large_batch(self):
|
||||
"""embed() should handle large batches."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
n = 100
|
||||
vectors = [[float(i)] for i in range(n)]
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=vectors,
|
||||
))
|
||||
|
||||
texts = [f"text {i}" for i in range(n)]
|
||||
result = await client.embed(texts=texts)
|
||||
|
||||
assert len(result) == n
|
||||
req = client.request.call_args[0][0]
|
||||
assert len(req.texts) == n
|
||||
135
tests/unit/test_embeddings/test_embeddings_service_request.py
Normal file
135
tests/unit/test_embeddings/test_embeddings_service_request.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
"""
|
||||
Tests for EmbeddingsService.on_request — the request handler that dispatches
|
||||
to on_embeddings and sends responses.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.base import EmbeddingsService
|
||||
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class StubEmbeddingsService(EmbeddingsService):
|
||||
"""Minimal concrete implementation for testing on_request."""
|
||||
|
||||
def __init__(self, embed_result=None, embed_error=None):
|
||||
# Skip super().__init__ to avoid taskgroup/registration
|
||||
self.embed_result = embed_result or [[0.1, 0.2]]
|
||||
self.embed_error = embed_error
|
||||
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
if self.embed_error:
|
||||
raise self.embed_error
|
||||
return self.embed_result
|
||||
|
||||
|
||||
def _make_msg(texts, msg_id="req-1"):
|
||||
request = EmbeddingsRequest(texts=texts)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = request
|
||||
msg.properties.return_value = {"id": msg_id}
|
||||
return msg
|
||||
|
||||
|
||||
def _make_flow(model="test-model"):
|
||||
mock_response_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
|
||||
def flow_callable(name):
|
||||
if name == "model":
|
||||
return model
|
||||
if name == "response":
|
||||
return mock_response_producer
|
||||
return MagicMock()
|
||||
|
||||
flow_callable.producer = {"response": mock_response_producer}
|
||||
return flow_callable, mock_response_producer
|
||||
|
||||
|
||||
class TestEmbeddingsServiceOnRequest:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_request(self):
|
||||
"""on_request should call on_embeddings and send response."""
|
||||
service = StubEmbeddingsService(embed_result=[[0.1, 0.2], [0.3, 0.4]])
|
||||
msg = _make_msg(["hello", "world"], msg_id="r1")
|
||||
flow, mock_response = _make_flow(model="my-model")
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
mock_response.send.assert_called_once()
|
||||
resp = mock_response.send.call_args[0][0]
|
||||
assert isinstance(resp, EmbeddingsResponse)
|
||||
assert resp.error is None
|
||||
assert resp.vectors == [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# Check id is passed through
|
||||
props = mock_response.send.call_args[1]["properties"]
|
||||
assert props["id"] == "r1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_model_from_flow(self):
|
||||
"""on_request should pass model parameter from flow to on_embeddings."""
|
||||
calls = []
|
||||
|
||||
class TrackingService(EmbeddingsService):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
calls.append({"texts": texts, "model": model})
|
||||
return [[0.0]]
|
||||
|
||||
service = TrackingService()
|
||||
msg = _make_msg(["test"])
|
||||
flow, _ = _make_flow(model="custom-model-v2")
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["model"] == "custom-model-v2"
|
||||
assert calls[0]["texts"] == ["test"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_sends_error_response(self):
|
||||
"""Non-rate-limit errors should send an error response."""
|
||||
service = StubEmbeddingsService(
|
||||
embed_error=ValueError("dimension mismatch")
|
||||
)
|
||||
msg = _make_msg(["test"], msg_id="r2")
|
||||
flow, mock_response = _make_flow()
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
mock_response.send.assert_called_once()
|
||||
resp = mock_response.send.call_args[0][0]
|
||||
assert resp.error is not None
|
||||
assert resp.error.type == "embeddings-error"
|
||||
assert "dimension mismatch" in resp.error.message
|
||||
assert resp.vectors == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_propagates(self):
|
||||
"""TooManyRequests should propagate (not caught as error response)."""
|
||||
service = StubEmbeddingsService(
|
||||
embed_error=TooManyRequests("rate limited")
|
||||
)
|
||||
msg = _make_msg(["test"])
|
||||
flow, _ = _make_flow()
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_id_preserved(self):
|
||||
"""The request message id should be forwarded in the response properties."""
|
||||
service = StubEmbeddingsService()
|
||||
msg = _make_msg(["test"], msg_id="unique-id-42")
|
||||
flow, mock_response = _make_flow()
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
props = mock_response.send.call_args[1]["properties"]
|
||||
assert props["id"] == "unique-id-42"
|
||||
233
tests/unit/test_embeddings/test_graph_embeddings_processor.py
Normal file
233
tests/unit/test_embeddings/test_graph_embeddings_processor.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
"""
|
||||
Tests for graph embeddings processor — batch embedding of entity contexts.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.embeddings.graph_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import (
|
||||
EntityContexts, EntityEmbeddings, GraphEmbeddings,
|
||||
Term, IRI, Metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
return Processor(
|
||||
taskgroup=AsyncMock(),
|
||||
id="test-graph-embeddings",
|
||||
batch_size=3,
|
||||
)
|
||||
|
||||
|
||||
def _make_entity_context(name, context, chunk_id="chunk-1"):
|
||||
"""Create an entity context for testing."""
|
||||
entity = Term(type=IRI, iri=f"urn:entity:{name}")
|
||||
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
|
||||
|
||||
|
||||
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
value = EntityContexts(metadata=metadata, entities=entities)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
return msg
|
||||
|
||||
|
||||
class TestGraphEmbeddingsInit:
|
||||
|
||||
def test_default_batch_size(self):
|
||||
p = Processor(taskgroup=AsyncMock(), id="test")
|
||||
assert p.batch_size == 5
|
||||
|
||||
def test_custom_batch_size(self):
|
||||
p = Processor(taskgroup=AsyncMock(), id="test", batch_size=20)
|
||||
assert p.batch_size == 20
|
||||
|
||||
|
||||
class TestGraphEmbeddingsBatchProcessing:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_batch_call_for_all_entities(self, processor):
|
||||
"""All entity contexts should be embedded in a single API call."""
|
||||
entities = [
|
||||
_make_entity_context("Alice", "Alice is a person"),
|
||||
_make_entity_context("Bob", "Bob is a developer"),
|
||||
_make_entity_context("Acme", "Acme is a company"),
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(return_value=[
|
||||
[0.1, 0.2], [0.3, 0.4], [0.5, 0.6],
|
||||
])
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# Single batch call with all three texts
|
||||
mock_embed.assert_called_once_with(
|
||||
texts=["Alice is a person", "Bob is a developer", "Acme is a company"]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vectors_paired_with_correct_entities(self, processor):
|
||||
"""Each vector should be paired with its corresponding entity."""
|
||||
entities = [
|
||||
_make_entity_context("Alice", "ctx-A", chunk_id="c1"),
|
||||
_make_entity_context("Bob", "ctx-B", chunk_id="c2"),
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
vectors = [[1.0, 2.0], [3.0, 4.0]]
|
||||
mock_embed = AsyncMock(return_value=vectors)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# With batch_size=3, all 2 entities fit in one output message
|
||||
mock_output.send.assert_called_once()
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert isinstance(result, GraphEmbeddings)
|
||||
assert len(result.entities) == 2
|
||||
assert result.entities[0].vector == [1.0, 2.0]
|
||||
assert result.entities[0].entity.iri == "urn:entity:Alice"
|
||||
assert result.entities[0].chunk_id == "c1"
|
||||
assert result.entities[1].vector == [3.0, 4.0]
|
||||
assert result.entities[1].entity.iri == "urn:entity:Bob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_batching(self, processor):
|
||||
"""Output should be split into batches of batch_size."""
|
||||
# batch_size=3, 7 entities -> 3 output messages (3+3+1)
|
||||
entities = [
|
||||
_make_entity_context(f"E{i}", f"context {i}")
|
||||
for i in range(7)
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
vectors = [[float(i)] for i in range(7)]
|
||||
mock_embed = AsyncMock(return_value=vectors)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert mock_output.send.call_count == 3
|
||||
# First batch has 3 entities
|
||||
batch1 = mock_output.send.call_args_list[0][0][0]
|
||||
assert len(batch1.entities) == 3
|
||||
# Second batch has 3 entities
|
||||
batch2 = mock_output.send.call_args_list[1][0][0]
|
||||
assert len(batch2.entities) == 3
|
||||
# Third batch has 1 entity
|
||||
batch3 = mock_output.send.call_args_list[2][0][0]
|
||||
assert len(batch3.entities) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_batches_preserve_metadata(self, processor):
|
||||
"""Each output batch should carry the original metadata."""
|
||||
entities = [
|
||||
_make_entity_context(f"E{i}", f"ctx {i}")
|
||||
for i in range(5)
|
||||
]
|
||||
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for call in mock_output.send.call_args_list:
|
||||
result = call[0][0]
|
||||
assert result.metadata.id == "doc-42"
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_entity(self, processor):
|
||||
"""Single entity should work with one embed call and one output."""
|
||||
entities = [_make_entity_context("Solo", "solo context")]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[1.0, 2.0, 3.0]])
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
mock_embed.assert_called_once_with(texts=["solo context"])
|
||||
mock_output.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_error_propagates(self, processor):
|
||||
"""Embedding service errors should propagate for retry."""
|
||||
entities = [_make_entity_context("E", "ctx")]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(side_effect=RuntimeError("embedding failed"))
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
return MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="embedding failed"):
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_batch_size(self, processor):
|
||||
"""When entity count equals batch_size, exactly one output message."""
|
||||
entities = [
|
||||
_make_entity_context(f"E{i}", f"ctx {i}")
|
||||
for i in range(3) # batch_size=3
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[0.0]] * 3)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
mock_output.send.assert_called_once()
|
||||
assert len(mock_output.send.call_args[0][0].entities) == 3
|
||||
Loading…
Add table
Add a link
Reference in a new issue