mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-31 02:23:36 +02:00
Fixing tests
This commit is contained in:
parent
61a100dec9
commit
df12467510
6 changed files with 44 additions and 44 deletions
|
|
@ -303,7 +303,7 @@ class TestDocumentRagIntegration:
|
|||
assert "DocumentRag initialized" in log_messages
|
||||
assert "Constructing prompt..." in log_messages
|
||||
assert "Computing embeddings..." in log_messages
|
||||
assert "chunk_ids" in log_messages.lower()
|
||||
assert "chunks" in log_messages.lower()
|
||||
assert "Invoking LLM..." in log_messages
|
||||
assert "Query processing complete" in log_messages
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -36,9 +36,9 @@ class TestGraphRagIntegration:
|
|||
"""Mock graph embeddings client that returns realistic entities"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95),
|
||||
EntityMatch(entity=Term(value="http://trustgraph.ai/e/artificial-intelligence", is_uri=True), score=0.90),
|
||||
EntityMatch(entity=Term(value="http://trustgraph.ai/e/neural-networks", is_uri=True), score=0.85)
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ response delivery through the complete pipeline.
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_rag_streaming_chunks,
|
||||
|
|
@ -34,7 +34,7 @@ class TestGraphRagStreaming:
|
|||
"""Mock graph embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95),
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
|
|
@ -27,8 +27,8 @@ class TestGraphRagStreamingProtocol:
|
|||
"""Mock graph embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
EntityMatch(entity=Term(value="entity1", is_uri=True), score=0.95),
|
||||
EntityMatch(entity=Term(value="entity2", is_uri=True), score=0.90)
|
||||
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
|
||||
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
|
|||
|
|
@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client = DocumentEmbeddingsClient()
|
||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
# Mock the request method
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
|
||||
|
||||
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
|
||||
# Act
|
||||
result = await client.query(
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
timeout=30
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||
client.request.assert_called_once()
|
||||
call_args = client.request.call_args[0][0]
|
||||
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
||||
assert call_args.vectors == vectors
|
||||
assert call_args.vector == vector
|
||||
assert call_args.limit == 10
|
||||
assert call_args.user == "test_user"
|
||||
assert call_args.collection == "test_collection"
|
||||
|
|
@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||
await client.query(
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client = DocumentEmbeddingsClient()
|
||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunk_ids = []
|
||||
|
||||
mock_response.chunks = []
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||
|
||||
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
|
@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client = DocumentEmbeddingsClient()
|
||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunk_ids = ["test_chunk"]
|
||||
|
||||
mock_response.chunks = ["test_chunk"]
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
client.request.assert_called_once()
|
||||
|
|
@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client = DocumentEmbeddingsClient()
|
||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunk_ids = ["chunk1"]
|
||||
|
||||
mock_response.chunks = ["chunk1"]
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
await client.query(
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
timeout=60
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert client.request.call_args[1]["timeout"] == 60
|
||||
|
||||
|
|
@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client = DocumentEmbeddingsClient()
|
||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunk_ids = ["test_chunk"]
|
||||
|
||||
mock_response.chunks = ["test_chunk"]
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
|
||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||
|
||||
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
mock_logger.debug.assert_called_once()
|
||||
assert "Document embeddings response" in str(mock_logger.debug.call_args)
|
||||
|
|
|
|||
|
|
@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
mock_response = MagicMock()
|
||||
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
client.call = MagicMock(return_value=mock_response)
|
||||
|
||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
|
||||
|
||||
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
|
||||
# Act
|
||||
result = client.request(
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
limit=10,
|
||||
timeout=300
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||
client.call.assert_called_once_with(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
limit=10,
|
||||
timeout=300
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue