trustgraph/tests/unit/test_embeddings/test_embeddings_client.py
cybermaggedon 29b4300808
Updated test suite for explainability & provenance (#696)
* Provenance tests

* Embeddings tests

* Test librarian

* Test triples stream

* Test concurrency

* Entity centric graph writes

* Agent tool service tests

* Structured data tests

* RDF tests

* Additional LLM tests

* Reliability tests
2026-03-13 14:27:42 +00:00

109 lines
3.7 KiB
Python

"""
Tests for EmbeddingsClient — the client interface for batch embeddings.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.base.embeddings_client import EmbeddingsClient
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
class TestEmbeddingsClient:
    """Unit tests for ``EmbeddingsClient.embed()``.

    Every test creates the client with ``EmbeddingsClient.__new__`` so that
    ``__init__`` is not executed, and replaces ``client.request`` with an
    ``AsyncMock`` that returns a canned ``EmbeddingsResponse``.
    """

    @staticmethod
    def _make_client(vectors, error=None):
        """Return a client whose ``request`` is an ``AsyncMock``.

        Args:
            vectors: Value for the ``vectors`` field of the canned response.
            error: Value for the ``error`` field of the canned response.

        Returns:
            An ``EmbeddingsClient`` created without running ``__init__``,
            with ``request`` mocked to return the canned response.
        """
        client = EmbeddingsClient.__new__(EmbeddingsClient)
        client.request = AsyncMock(
            return_value=EmbeddingsResponse(error=error, vectors=vectors),
        )
        return client

    @pytest.mark.asyncio
    async def test_embed_sends_request_and_returns_vectors(self):
        """embed() should send an EmbeddingsRequest and return vectors."""
        client = self._make_client(vectors=[[0.1, 0.2], [0.3, 0.4]])

        result = await client.embed(texts=["hello", "world"])

        assert result == [[0.1, 0.2], [0.3, 0.4]]
        client.request.assert_called_once()
        # The request object is the first positional argument.
        req = client.request.call_args[0][0]
        assert isinstance(req, EmbeddingsRequest)
        assert req.texts == ["hello", "world"]

    @pytest.mark.asyncio
    async def test_embed_single_text(self):
        """embed() should work with a single text."""
        client = self._make_client(vectors=[[1.0, 2.0, 3.0]])

        result = await client.embed(texts=["single"])

        assert result == [[1.0, 2.0, 3.0]]

    @pytest.mark.asyncio
    async def test_embed_raises_on_error_response(self):
        """embed() should raise RuntimeError when response contains an error."""
        client = self._make_client(
            vectors=[],
            error=Error(type="embeddings-error", message="model not found"),
        )

        with pytest.raises(RuntimeError, match="model not found"):
            await client.embed(texts=["test"])

    @pytest.mark.asyncio
    async def test_embed_passes_timeout(self):
        """embed() should pass timeout to the underlying request."""
        client = self._make_client(vectors=[[0.0]])

        await client.embed(texts=["test"], timeout=60)

        _, kwargs = client.request.call_args
        assert kwargs["timeout"] == 60

    @pytest.mark.asyncio
    async def test_embed_default_timeout(self):
        """embed() should use 300s default timeout."""
        client = self._make_client(vectors=[[0.0]])

        await client.embed(texts=["test"])

        _, kwargs = client.request.call_args
        assert kwargs["timeout"] == 300

    @pytest.mark.asyncio
    async def test_embed_empty_texts(self):
        """embed() with empty list should still make the request."""
        client = self._make_client(vectors=[])

        result = await client.embed(texts=[])

        assert result == []
        # Verify the request really was issued; checking only the return
        # value would also pass if embed() skipped the call on empty input.
        client.request.assert_called_once()
        req = client.request.call_args[0][0]
        assert isinstance(req, EmbeddingsRequest)
        assert req.texts == []

    @pytest.mark.asyncio
    async def test_embed_large_batch(self):
        """embed() should handle large batches."""
        n = 100
        texts = [f"text {i}" for i in range(n)]
        client = self._make_client(vectors=[[float(i)] for i in range(n)])

        result = await client.embed(texts=texts)

        assert len(result) == n
        req = client.request.call_args[0][0]
        assert len(req.texts) == n