Update tests

This commit is contained in:
Cyber MacGeddon 2026-03-09 10:01:55 +00:00
parent 356d7f75ac
commit dcee1b8de2

View file

@ -175,9 +175,14 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors] mock_embeddings_client.embed.return_value = [test_vectors]
# Mock document embeddings returns chunk_ids # Mock document embeddings returns ChunkMatch objects
test_chunk_ids = ["doc/c1", "doc/c2"] mock_match1 = MagicMock()
mock_doc_embeddings_client.query.return_value = test_chunk_ids mock_match1.chunk_id = "doc/c1"
mock_match1.score = 0.95
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c2"
mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query # Initialize Query
query = Query( query = Query(
@ -195,9 +200,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list) # Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query]) mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify doc embeddings client was called correctly (with extracted vectors) # Verify doc embeddings client was called correctly (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors, vector=test_vectors,
limit=15, limit=15,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -218,11 +223,16 @@ class TestQuery:
# Mock embeddings and document embeddings responses # Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0] # New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
test_chunk_ids = ["doc/c3", "doc/c4"] mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c4"
mock_match2.score = 0.8
expected_response = "This is the document RAG response" expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors] mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = test_chunk_ids mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag # Initialize DocumentRag
@ -245,9 +255,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list) # Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"]) mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify doc embeddings client was called (with extracted vectors) # Verify doc embeddings client was called (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors, vector=test_vectors,
limit=10, limit=10,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -275,7 +285,10 @@ class TestQuery:
# Mock responses (batch format) # Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = ["doc/c5"] mock_match = MagicMock()
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response" mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag # Initialize DocumentRag
@ -289,9 +302,9 @@ class TestQuery:
# Call DocumentRag.query with minimal parameters # Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query") result = await document_rag.query("simple query")
# Verify default parameters were used (vectors extracted from batch) # Verify default parameters were used (vector extracted from batch)
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2]], vector=[[0.1, 0.2]],
limit=20, # Default doc_limit limit=20, # Default doc_limit
user="trustgraph", # Default user user="trustgraph", # Default user
collection="default" # Default collection collection="default" # Default collection
@ -316,7 +329,10 @@ class TestQuery:
# Mock responses (batch format) # Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_doc_embeddings_client.query.return_value = ["doc/c6"] mock_match = MagicMock()
mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True # Initialize Query with verbose=True
query = Query( query = Query(
@ -347,7 +363,10 @@ class TestQuery:
# Mock responses (batch format) # Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_doc_embeddings_client.query.return_value = ["doc/c7"] mock_match = MagicMock()
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response" mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True # Initialize DocumentRag with verbose=True
@ -487,7 +506,13 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors] mock_embeddings_client.embed.return_value = [query_vectors]
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids mock_matches = []
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock()
mock_match.chunk_id = chunk_id
mock_match.score = 0.9
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches
mock_prompt_client.document_prompt.return_value = final_response mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag # Initialize DocumentRag
@ -511,7 +536,7 @@ class TestQuery:
mock_embeddings_client.embed.assert_called_once_with([query_text]) mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
query_vectors, vector=query_vectors,
limit=25, limit=25,
user="research_user", user="research_user",
collection="ml_knowledge" collection="ml_knowledge"