diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 99e25ed5..80520429 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -9,6 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing. import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import ChunkMatch # Sample chunk content for testing - maps chunk_id to content @@ -39,10 +40,14 @@ class TestDocumentRagIntegration: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns chunk IDs""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() - # Now returns chunk_ids instead of actual content - client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] + # Returns ChunkMatch objects with chunk_id and score + client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90), + ChunkMatch(chunk_id="doc/c3", score=0.85) + ] return client @pytest.fixture diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index db79ebac..dad30a8f 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -8,6 +8,7 @@ response delivery through the complete pipeline. import pytest from unittest.mock import AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import ChunkMatch from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_callback_invoked, @@ -36,10 +37,14 @@ class TestDocumentRagStreaming: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns chunk IDs""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() - # Now returns chunk_ids instead of actual content - client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] + # Returns ChunkMatch objects with chunk_id and score + client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90), + ChunkMatch(chunk_id="doc/c3", score=0.85) + ] return client @pytest.fixture diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 94e8cf08..d3b6c2ba 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.schema import EntityMatch, Term @pytest.mark.integration @@ -35,9 +36,9 @@ class TestGraphRagIntegration: """Mock graph embeddings client that returns realistic entities""" client = AsyncMock() client.query.return_value = [ - "http://trustgraph.ai/e/machine-learning", - "http://trustgraph.ai/e/artificial-intelligence", - "http://trustgraph.ai/e/neural-networks" + EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95), + EntityMatch(entity=Term(value="http://trustgraph.ai/e/artificial-intelligence", is_uri=True), score=0.90), + EntityMatch(entity=Term(value="http://trustgraph.ai/e/neural-networks", is_uri=True), score=0.85) ] return client diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index f4d8ce8b..d1eb0099 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -8,6 +8,7 @@ response delivery through the complete pipeline. import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.schema import EntityMatch, Term from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_rag_streaming_chunks, @@ -33,7 +34,7 @@ class TestGraphRagStreaming: """Mock graph embeddings client""" client = AsyncMock() client.query.return_value = [ - "http://trustgraph.ai/e/machine-learning", + EntityMatch(entity=Term(value="http://trustgraph.ai/e/machine-learning", is_uri=True), score=0.95), ] return client diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index 2baa1d4d..c390c139 100644 --- a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -411,7 +411,7 @@ class TestKnowledgeGraphPipelineIntegration: entities=[ EntityEmbeddings( entity=Term(type=IRI, iri="http://example.org/entity"), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) ] ) diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 19f2cf35..3c484f27 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, call from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import EntityMatch, ChunkMatch, Term class TestGraphRagStreamingProtocol: @@ -25,7 +26,10 @@ class TestGraphRagStreamingProtocol: def mock_graph_embeddings_client(self): """Mock graph embeddings client""" client = AsyncMock() - client.query.return_value = ["entity1", "entity2"] + client.query.return_value = [ + EntityMatch(entity=Term(value="entity1", is_uri=True), score=0.95), + EntityMatch(entity=Term(value="entity2", is_uri=True), score=0.90) + ] return client @pytest.fixture @@ -202,9 +206,12 @@ class TestDocumentRagStreamingProtocol: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns chunk IDs""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() - client.query.return_value = ["doc/c1", "doc/c2"] + client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90) + ] return client @pytest.fixture