mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 18:36:22 +02:00
Fix/embeddings integration 2 (#670)
This commit is contained in:
parent
919b760c05
commit
4fa7cc7d7c
7 changed files with 90 additions and 77 deletions
|
|
@ -27,9 +27,13 @@ class TestDocumentRagIntegration:
|
||||||
def mock_embeddings_client(self):
|
def mock_embeddings_client(self):
|
||||||
"""Mock embeddings client that returns realistic vector embeddings"""
|
"""Mock embeddings client that returns realistic vector embeddings"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
|
# New batch format: [[[vectors_for_text1], ...]]
|
||||||
|
# One text input returns one vector set containing two vectors
|
||||||
client.embed.return_value = [
|
client.embed.return_value = [
|
||||||
[0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding
|
[
|
||||||
[0.6, 0.7, 0.8, 0.9, 1.0] # Second embedding for testing
|
[0.1, 0.2, 0.3, 0.4, 0.5], # First vector for text
|
||||||
|
[0.6, 0.7, 0.8, 0.9, 1.0] # Second vector for text
|
||||||
|
]
|
||||||
]
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
@ -90,7 +94,7 @@ class TestDocumentRagIntegration:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert - Verify service coordination
|
# Assert - Verify service coordination
|
||||||
mock_embeddings_client.embed.assert_called_once_with(query)
|
mock_embeddings_client.embed.assert_called_once_with([query])
|
||||||
|
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,8 @@ class TestDocumentRagStreaming:
|
||||||
def mock_embeddings_client(self):
|
def mock_embeddings_client(self):
|
||||||
"""Mock embeddings client"""
|
"""Mock embeddings client"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
# New batch format: [[[vectors_for_text1]]]
|
||||||
|
client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -21,8 +21,12 @@ class TestGraphRagIntegration:
|
||||||
def mock_embeddings_client(self):
|
def mock_embeddings_client(self):
|
||||||
"""Mock embeddings client that returns realistic vector embeddings"""
|
"""Mock embeddings client that returns realistic vector embeddings"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
|
# New batch format: [[[vectors_for_text1], ...]]
|
||||||
|
# One text input returns one vector set containing one vector
|
||||||
client.embed.return_value = [
|
client.embed.return_value = [
|
||||||
[0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding
|
[
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5], # Vector for text
|
||||||
|
]
|
||||||
]
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
@ -120,8 +124,8 @@ class TestGraphRagIntegration:
|
||||||
|
|
||||||
# Assert - Verify service coordination
|
# Assert - Verify service coordination
|
||||||
|
|
||||||
# 1. Should compute embeddings for query
|
# 1. Should compute embeddings for query (now expects list of texts)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(query)
|
mock_embeddings_client.embed.assert_called_once_with([query])
|
||||||
|
|
||||||
# 2. Should query graph embeddings to find relevant entities
|
# 2. Should query graph embeddings to find relevant entities
|
||||||
mock_graph_embeddings_client.query.assert_called_once()
|
mock_graph_embeddings_client.query.assert_called_once()
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,8 @@ class TestGraphRagStreaming:
|
||||||
def mock_embeddings_client(self):
|
def mock_embeddings_client(self):
|
||||||
"""Mock embeddings client"""
|
"""Mock embeddings client"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
# New batch format: [[[vectors_for_text1]]]
|
||||||
|
client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class TestGraphRagStreamingProtocol:
|
||||||
def mock_embeddings_client(self):
|
def mock_embeddings_client(self):
|
||||||
"""Mock embeddings client"""
|
"""Mock embeddings client"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.embed.return_value = [[0.1, 0.2, 0.3]]
|
client.embed.return_value = [[[0.1, 0.2, 0.3]]]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -197,7 +197,7 @@ class TestDocumentRagStreamingProtocol:
|
||||||
def mock_embeddings_client(self):
|
def mock_embeddings_client(self):
|
||||||
"""Mock embeddings client"""
|
"""Mock embeddings client"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.embed.return_value = [[0.1, 0.2, 0.3]]
|
client.embed.return_value = [[[0.1, 0.2, 0.3]]]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -132,9 +132,10 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
|
|
||||||
# Mock the embed method to return test vectors
|
# Mock the embed method to return test vectors in batch format
|
||||||
|
# New format: [[[vectors_for_text1]]] - returns first text's vector set
|
||||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
mock_embeddings_client.embed.return_value = expected_vectors
|
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -148,10 +149,10 @@ class TestQuery:
|
||||||
test_query = "What documents are relevant?"
|
test_query = "What documents are relevant?"
|
||||||
result = await query.get_vector(test_query)
|
result = await query.get_vector(test_query)
|
||||||
|
|
||||||
# Verify embeddings client was called correctly
|
# Verify embeddings client was called correctly (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify result matches expected vectors
|
# Verify result matches expected vectors (extracted from batch)
|
||||||
assert result == expected_vectors
|
assert result == expected_vectors
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -170,8 +171,9 @@ class TestQuery:
|
||||||
mock_rag.fetch_chunk = mock_fetch
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
# Mock the embedding and document query responses
|
# Mock the embedding and document query responses
|
||||||
|
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
mock_embeddings_client.embed.return_value = test_vectors
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
|
|
||||||
# Mock document embeddings returns chunk_ids
|
# Mock document embeddings returns chunk_ids
|
||||||
test_chunk_ids = ["doc/c1", "doc/c2"]
|
test_chunk_ids = ["doc/c1", "doc/c2"]
|
||||||
|
|
@ -190,10 +192,10 @@ class TestQuery:
|
||||||
test_query = "Find relevant documents"
|
test_query = "Find relevant documents"
|
||||||
result = await query.get_docs(test_query)
|
result = await query.get_docs(test_query)
|
||||||
|
|
||||||
# Verify embeddings client was called
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify doc embeddings client was called correctly
|
# Verify doc embeddings client was called correctly (with extracted vectors)
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
test_vectors,
|
test_vectors,
|
||||||
limit=15,
|
limit=15,
|
||||||
|
|
@ -214,11 +216,12 @@ class TestQuery:
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock embeddings and document embeddings responses
|
# Mock embeddings and document embeddings responses
|
||||||
|
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
test_chunk_ids = ["doc/c3", "doc/c4"]
|
test_chunk_ids = ["doc/c3", "doc/c4"]
|
||||||
expected_response = "This is the document RAG response"
|
expected_response = "This is the document RAG response"
|
||||||
|
|
||||||
mock_embeddings_client.embed.return_value = test_vectors
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
mock_doc_embeddings_client.query.return_value = test_chunk_ids
|
mock_doc_embeddings_client.query.return_value = test_chunk_ids
|
||||||
mock_prompt_client.document_prompt.return_value = expected_response
|
mock_prompt_client.document_prompt.return_value = expected_response
|
||||||
|
|
||||||
|
|
@ -239,10 +242,10 @@ class TestQuery:
|
||||||
doc_limit=10
|
doc_limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify embeddings client was called
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with("test query")
|
mock_embeddings_client.embed.assert_called_once_with(["test query"])
|
||||||
|
|
||||||
# Verify doc embeddings client was called
|
# Verify doc embeddings client was called (with extracted vectors)
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
test_vectors,
|
test_vectors,
|
||||||
limit=10,
|
limit=10,
|
||||||
|
|
@ -270,8 +273,8 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock responses
|
# Mock responses (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
|
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
|
||||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||||
|
|
||||||
|
|
@ -286,7 +289,7 @@ class TestQuery:
|
||||||
# Call DocumentRag.query with minimal parameters
|
# Call DocumentRag.query with minimal parameters
|
||||||
result = await document_rag.query("simple query")
|
result = await document_rag.query("simple query")
|
||||||
|
|
||||||
# Verify default parameters were used
|
# Verify default parameters were used (vectors extracted from batch)
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
[[0.1, 0.2]],
|
[[0.1, 0.2]],
|
||||||
limit=20, # Default doc_limit
|
limit=20, # Default doc_limit
|
||||||
|
|
@ -311,8 +314,8 @@ class TestQuery:
|
||||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
mock_rag.fetch_chunk = mock_fetch
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
# Mock responses
|
# Mock responses (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
|
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
|
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
|
||||||
|
|
||||||
# Initialize Query with verbose=True
|
# Initialize Query with verbose=True
|
||||||
|
|
@ -327,8 +330,8 @@ class TestQuery:
|
||||||
# Call get_docs
|
# Call get_docs
|
||||||
result = await query.get_docs("verbose test")
|
result = await query.get_docs("verbose test")
|
||||||
|
|
||||||
# Verify calls were made
|
# Verify calls were made (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with("verbose test")
|
mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
|
|
||||||
# Verify result contains fetched content
|
# Verify result contains fetched content
|
||||||
|
|
@ -342,8 +345,8 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock responses
|
# Mock responses (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
|
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
|
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
|
||||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||||
|
|
||||||
|
|
@ -359,8 +362,8 @@ class TestQuery:
|
||||||
# Call DocumentRag.query
|
# Call DocumentRag.query
|
||||||
result = await document_rag.query("verbose query test")
|
result = await document_rag.query("verbose query test")
|
||||||
|
|
||||||
# Verify all clients were called
|
# Verify all clients were called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with("verbose query test")
|
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
|
|
||||||
# Verify prompt client was called with fetched content
|
# Verify prompt client was called with fetched content
|
||||||
|
|
@ -385,8 +388,8 @@ class TestQuery:
|
||||||
return f"Content for {chunk_id}"
|
return f"Content for {chunk_id}"
|
||||||
mock_rag.fetch_chunk = mock_fetch
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
# Mock responses - empty chunk_id list
|
# Mock responses - empty chunk_id list (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
|
|
@ -400,8 +403,8 @@ class TestQuery:
|
||||||
# Call get_docs
|
# Call get_docs
|
||||||
result = await query.get_docs("query with no results")
|
result = await query.get_docs("query with no results")
|
||||||
|
|
||||||
# Verify calls were made
|
# Verify calls were made (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with("query with no results")
|
mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
|
|
||||||
# Verify empty result is returned
|
# Verify empty result is returned
|
||||||
|
|
@ -415,8 +418,8 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock responses - no chunk_ids found
|
# Mock responses - no chunk_ids found (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
|
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
|
||||||
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
|
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
|
||||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||||
|
|
||||||
|
|
@ -448,9 +451,9 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
|
|
||||||
# Mock the embed method
|
# Mock the embed method (batch format)
|
||||||
expected_vectors = [[0.9, 1.0, 1.1]]
|
expected_vectors = [[0.9, 1.0, 1.1]]
|
||||||
mock_embeddings_client.embed.return_value = expected_vectors
|
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||||
|
|
||||||
# Initialize Query with verbose=True
|
# Initialize Query with verbose=True
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -463,10 +466,10 @@ class TestQuery:
|
||||||
# Call get_vector
|
# Call get_vector
|
||||||
result = await query.get_vector("verbose vector test")
|
result = await query.get_vector("verbose vector test")
|
||||||
|
|
||||||
# Verify embeddings client was called
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with("verbose vector test")
|
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
|
||||||
|
|
||||||
# Verify result
|
# Verify result (extracted from batch)
|
||||||
assert result == expected_vectors
|
assert result == expected_vectors
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -477,13 +480,13 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock realistic responses
|
# Mock realistic responses (batch format)
|
||||||
query_text = "What is machine learning?"
|
query_text = "What is machine learning?"
|
||||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||||
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
|
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
|
||||||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||||
|
|
||||||
mock_embeddings_client.embed.return_value = query_vectors
|
mock_embeddings_client.embed.return_value = [query_vectors]
|
||||||
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
|
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
|
||||||
mock_prompt_client.document_prompt.return_value = final_response
|
mock_prompt_client.document_prompt.return_value = final_response
|
||||||
|
|
||||||
|
|
@ -504,8 +507,8 @@ class TestQuery:
|
||||||
doc_limit=25
|
doc_limit=25
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify complete pipeline execution
|
# Verify complete pipeline execution (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(query_text)
|
mock_embeddings_client.embed.assert_called_once_with([query_text])
|
||||||
|
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
query_vectors,
|
query_vectors,
|
||||||
|
|
|
||||||
|
|
@ -127,9 +127,9 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
|
|
||||||
# Mock the embed method to return test vectors
|
# Mock the embed method to return test vectors (batch format)
|
||||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
mock_embeddings_client.embed.return_value = expected_vectors
|
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -143,10 +143,10 @@ class TestQuery:
|
||||||
test_query = "What is the capital of France?"
|
test_query = "What is the capital of France?"
|
||||||
result = await query.get_vector(test_query)
|
result = await query.get_vector(test_query)
|
||||||
|
|
||||||
# Verify embeddings client was called correctly
|
# Verify embeddings client was called correctly (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify result matches expected vectors
|
# Verify result matches expected vectors (extracted from batch)
|
||||||
assert result == expected_vectors
|
assert result == expected_vectors
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -157,9 +157,9 @@ class TestQuery:
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
|
|
||||||
# Mock the embed method
|
# Mock the embed method (batch format)
|
||||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||||
mock_embeddings_client.embed.return_value = expected_vectors
|
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||||
|
|
||||||
# Initialize Query with verbose=True
|
# Initialize Query with verbose=True
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -173,10 +173,10 @@ class TestQuery:
|
||||||
test_query = "Test query for embeddings"
|
test_query = "Test query for embeddings"
|
||||||
result = await query.get_vector(test_query)
|
result = await query.get_vector(test_query)
|
||||||
|
|
||||||
# Verify embeddings client was called correctly
|
# Verify embeddings client was called correctly (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify result matches expected vectors
|
# Verify result matches expected vectors (extracted from batch)
|
||||||
assert result == expected_vectors
|
assert result == expected_vectors
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -189,9 +189,9 @@ class TestQuery:
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||||
|
|
||||||
# Mock the embedding and entity query responses
|
# Mock the embedding and entity query responses (batch format)
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
mock_embeddings_client.embed.return_value = test_vectors
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
|
|
||||||
# Mock entity objects that have string representation
|
# Mock entity objects that have string representation
|
||||||
mock_entity1 = MagicMock()
|
mock_entity1 = MagicMock()
|
||||||
|
|
@ -213,10 +213,10 @@ class TestQuery:
|
||||||
test_query = "Find related entities"
|
test_query = "Find related entities"
|
||||||
result = await query.get_entities(test_query)
|
result = await query.get_entities(test_query)
|
||||||
|
|
||||||
# Verify embeddings client was called
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify graph embeddings client was called correctly
|
# Verify graph embeddings client was called correctly (with extracted vectors)
|
||||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||||
vectors=test_vectors,
|
vectors=test_vectors,
|
||||||
limit=25,
|
limit=25,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue