diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 1c4f5fe2..e9df05cf 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -303,7 +303,7 @@ class TestDocumentRagIntegration: assert "DocumentRag initialized" in log_messages assert "Constructing prompt..." in log_messages assert "Computing embeddings..." in log_messages - assert "chunk_ids" in log_messages.lower() + assert "chunks" in log_messages.lower() assert "Invoking LLM..." in log_messages assert "Query processing complete" in log_messages diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 260b93d5..7a15edb8 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -11,7 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag -from trustgraph.schema import EntityMatch, Term +from trustgraph.schema import EntityMatch, Term, IRI @pytest.mark.integration @@ -36,9 +36,9 @@ class TestGraphRagIntegration: """Mock graph embeddings client that returns realistic entities""" client = AsyncMock() client.query.return_value = [ - EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95), - EntityMatch(entity=Term(value="http://trustgraph.ai/e/artificial-intelligence", is_uri=True), score=0.90), - EntityMatch(entity=Term(value="http://trustgraph.ai/e/neural-networks", is_uri=True), score=0.85) + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95), + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90), + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85) ] return client diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index d1eb0099..99880510 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -8,7 +8,7 @@ response delivery through the complete pipeline. import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag -from trustgraph.schema import EntityMatch, Term +from trustgraph.schema import EntityMatch, Term, IRI from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_rag_streaming_chunks, @@ -34,7 +34,7 @@ class TestGraphRagStreaming: """Mock graph embeddings client""" client = AsyncMock() client.query.return_value = [ - EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95), + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95), ] return client diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 3c484f27..4fa93afd 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -9,7 +9,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, call from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag -from trustgraph.schema import EntityMatch, ChunkMatch, Term +from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI class TestGraphRagStreamingProtocol: @@ -27,8 +27,8 @@ class TestGraphRagStreamingProtocol: """Mock graph embeddings client""" client = AsyncMock() client.query.return_value = [ - EntityMatch(entity=Term(value="entity1", is_uri=True), score=0.95), - EntityMatch(entity=Term(value="entity2", is_uri=True), score=0.90) + EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95), + EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90) ] return client diff --git a/tests/unit/test_base/test_document_embeddings_client.py b/tests/unit/test_base/test_document_embeddings_client.py index 81d4a98e..705f2bd1 100644 --- a/tests/unit/test_base/test_document_embeddings_client.py +++ b/tests/unit/test_base/test_document_embeddings_client.py @@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"] - + mock_response.chunks = ["chunk1", "chunk2", "chunk3"] + # Mock the request method client.request = AsyncMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - + + vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + # Act result = await client.query( - vectors=vectors, + vector=vector, limit=10, user="test_user", collection="test_collection", timeout=30 ) - + # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.request.assert_called_once() call_args = client.request.call_args[0][0] assert isinstance(call_args, DocumentEmbeddingsRequest) - assert call_args.vectors == vectors + assert call_args.vector == vector assert call_args.limit == 10 assert call_args.user == "test_user" assert call_args.collection == "test_collection" @@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(RuntimeError, match="Database connection failed"): await client.query( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = [] - + mock_response.chunks = [] + client.request = AsyncMock(return_value=mock_response) - + # Act - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) - + result = await client.query(vector=[0.1, 0.2, 0.3]) + # Assert assert result == [] @@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["test_chunk"] - + mock_response.chunks = ["test_chunk"] + client.request = AsyncMock(return_value=mock_response) - + # Act - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) + result = await client.query(vector=[0.1, 0.2, 0.3]) # Assert client.request.assert_called_once() @@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["chunk1"] - + mock_response.chunks = ["chunk1"] + client.request = AsyncMock(return_value=mock_response) - + # Act await client.query( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], timeout=60 ) - + # Assert assert client.request.call_args[1]["timeout"] == 60 @@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["test_chunk"] - + mock_response.chunks = ["test_chunk"] + client.request = AsyncMock(return_value=mock_response) - + # Act with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger: - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) - + result = await client.query(vector=[0.1, 0.2, 0.3]) + # Assert mock_logger.debug.assert_called_once() assert "Document embeddings response" in str(mock_logger.debug.call_args) diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py index 5873d81c..2458f583 100644 --- a/tests/unit/test_clients/test_sync_document_embeddings_client.py +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["chunk1", "chunk2", "chunk3"] client.call = MagicMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - + + vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + # Act result = client.request( - vectors=vectors, + vector=vector, user="test_user", collection="test_collection", limit=10, timeout=300 ) - + # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.call.assert_called_once_with( user="test_user", collection="test_collection", - vectors=vectors, + vector=vector, limit=10, timeout=300 )