Fixing tests

This commit is contained in:
Cyber MacGeddon 2026-03-09 10:14:01 +00:00
parent 61a100dec9
commit df12467510
6 changed files with 44 additions and 44 deletions

View file

@ -303,7 +303,7 @@ class TestDocumentRagIntegration:
assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages
assert "chunk_ids" in log_messages.lower()
assert "chunks" in log_messages.lower()
assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages

View file

@ -11,7 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term
from trustgraph.schema import EntityMatch, Term, IRI
@pytest.mark.integration
@ -36,9 +36,9 @@ class TestGraphRagIntegration:
"""Mock graph embeddings client that returns realistic entities"""
client = AsyncMock()
client.query.return_value = [
EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95),
EntityMatch(entity=Term(value="http://trustgraph.ai/e/artificial-intelligence", is_uri=True), score=0.90),
EntityMatch(entity=Term(value="http://trustgraph.ai/e/neural-networks", is_uri=True), score=0.85)
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
]
return client

View file

@ -8,7 +8,7 @@ response delivery through the complete pipeline.
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term
from trustgraph.schema import EntityMatch, Term, IRI
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@ -34,7 +34,7 @@ class TestGraphRagStreaming:
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = [
EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
]
return client

View file

@ -9,7 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
class TestGraphRagStreamingProtocol:
@ -27,8 +27,8 @@ class TestGraphRagStreamingProtocol:
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = [
EntityMatch(entity=Term(value="entity1", is_uri=True), score=0.95),
EntityMatch(entity=Term(value="entity2", is_uri=True), score=0.90)
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
]
return client

View file

@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
# Mock the request method
client.request = AsyncMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = await client.query(
vectors=vectors,
vector=vector,
limit=10,
user="test_user",
collection="test_collection",
timeout=30
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.request.assert_called_once()
call_args = client.request.call_args[0][0]
assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vectors == vectors
assert call_args.vector == vector
assert call_args.limit == 10
assert call_args.user == "test_user"
assert call_args.collection == "test_collection"
@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(RuntimeError, match="Database connection failed"):
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = []
mock_response.chunks = []
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
assert result == []
@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["test_chunk"]
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
client.request.assert_called_once()
@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["chunk1"]
mock_response.chunks = ["chunk1"]
client.request = AsyncMock(return_value=mock_response)
# Act
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
timeout=60
)
# Assert
assert client.request.call_args[1]["timeout"] == 60
@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["test_chunk"]
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
mock_logger.debug.assert_called_once()
assert "Document embeddings response" in str(mock_logger.debug.call_args)

View file

@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = client.request(
vectors=vectors,
vector=vector,
user="test_user",
collection="test_collection",
limit=10,
timeout=300
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with(
user="test_user",
collection="test_collection",
vectors=vectors,
vector=vector,
limit=10,
timeout=300
)