release/v2.4 -> master (#844)

This commit is contained in:
cybermaggedon 2026-04-22 15:19:57 +01:00 committed by GitHub
parent a24df8e990
commit 89cabee1b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
386 changed files with 7202 additions and 5741 deletions

View file

@ -92,14 +92,13 @@ class TestQuery:
# Initialize Query with defaults
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
assert query.collection == "test_collection"
assert query.verbose is False
assert query.doc_limit == 20 # Default value
@ -112,7 +111,7 @@ class TestQuery:
# Initialize Query with custom doc_limit
query = Query(
rag=mock_rag,
user="custom_user",
workspace="test_workspace",
collection="custom_collection",
verbose=True,
doc_limit=50
@ -120,7 +119,6 @@ class TestQuery:
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.doc_limit == 50
@ -137,7 +135,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -162,7 +160,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -184,7 +182,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -223,7 +221,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=15
@ -240,7 +238,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
limit=15,
user="test_user",
collection="test_collection"
)
@ -286,7 +283,6 @@ class TestQuery:
result = await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=10
)
@ -304,7 +300,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
)
@ -350,7 +345,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
user="trustgraph", # Default user
collection="default" # Default collection
)
@ -380,7 +374,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=True,
doc_limit=5
@ -453,7 +447,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -509,7 +503,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=True
)
@ -558,7 +552,6 @@ class TestQuery:
result = await document_rag.query(
query=query_text,
user="research_user",
collection="ml_knowledge",
doc_limit=25
)
@ -619,7 +612,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=10

View file

@ -1,6 +1,6 @@
"""
Unit test for DocumentRAG service parameter passing fix.
Tests that user and collection parameters from the message are correctly
Tests that the collection parameter from the message is correctly
passed to the DocumentRag.query() method.
"""
@ -16,13 +16,13 @@ class TestDocumentRagService:
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
@pytest.mark.asyncio
async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
async def test_collection_parameter_passed_to_query(self, mock_document_rag_class):
"""
Test that user and collection from message are passed to DocumentRag.query().
This is a regression test for the bug where user/collection parameters
were ignored, causing wrong collection names like 'd_trustgraph_default_384'
instead of 'd_my_user_test_coll_1_384'.
Test that collection from message is passed to DocumentRag.query().
This is a regression test for the bug where the collection parameter
was ignored, causing wrong collection names like 'd_trustgraph_default_384'
instead of one that reflects the requested collection.
"""
# Setup processor
processor = Processor(
@ -30,17 +30,16 @@ class TestDocumentRagService:
id="test-processor",
doc_limit=10
)
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom user/collection
# Setup message with custom collection
msg = MagicMock()
msg.value.return_value = DocumentRagQuery(
query="test query",
user="my_user", # Custom user (not default "trustgraph")
collection="test_coll_1", # Custom collection (not default "default")
doc_limit=5
)
@ -64,7 +63,7 @@ class TestDocumentRagService:
# Verify: DocumentRag.query was called with correct parameters
mock_rag_instance.query.assert_called_once_with(
"test query",
user="my_user", # Must be from message, not hardcoded default
workspace=ANY, # Workspace comes from flow.workspace (mock)
collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5,
explain_callback=ANY, # Explainability callback is always passed
@ -103,7 +102,6 @@ class TestDocumentRagService:
msg = MagicMock()
msg.value.return_value = DocumentRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
doc_limit=10,
streaming=False # Non-streaming mode

View file

@ -78,14 +78,12 @@ class TestQuery:
# Initialize Query with defaults
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
assert query.collection == "test_collection"
assert query.verbose is False
assert query.entity_limit == 50 # Default value
@ -101,7 +99,6 @@ class TestQuery:
# Initialize Query with custom parameters
query = Query(
rag=mock_rag,
user="custom_user",
collection="custom_collection",
verbose=True,
entity_limit=100,
@ -112,7 +109,6 @@ class TestQuery:
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.entity_limit == 100
@ -133,7 +129,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -156,7 +151,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=True
)
@ -177,7 +171,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -201,7 +194,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -244,7 +236,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
entity_limit=25
@ -269,7 +260,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -277,7 +267,7 @@ class TestQuery:
result = await query.maybe_label("entity1")
assert result == "Entity One Label"
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
mock_cache.get.assert_called_once_with("test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
@ -295,7 +285,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -307,13 +296,12 @@ class TestQuery:
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection",
g=""
)
assert result == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
cache_key = "test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@pytest.mark.asyncio
@ -330,7 +318,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -342,13 +329,12 @@ class TestQuery:
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection",
g=""
)
assert result == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
cache_key = "test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
@ -375,7 +361,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
triple_limit=10
@ -388,15 +373,15 @@ class TestQuery:
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
expected_subgraph = {
@ -415,7 +400,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -435,7 +419,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_subgraph_size=2
@ -455,7 +438,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_path_length=1
@ -493,7 +475,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_subgraph_size=100
@ -601,7 +582,6 @@ class TestQuery:
try:
response = await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
entity_limit=25,
triple_limit=15,

View file

@ -120,7 +120,6 @@ class TestGraphRagServiceExplainTriples:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is quantum computing?",
user="trustgraph",
collection="default",
streaming=False,
)

View file

@ -52,7 +52,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
@ -123,7 +122,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
@ -190,7 +188,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="Test query",
user="trustgraph",
collection="default",
streaming=False
)

View file

@ -286,11 +286,11 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert
assert "test_schema" in processor.schemas
schema = processor.schemas["test_schema"]
assert "test_schema" in processor.schemas["default"]
schema = processor.schemas["default"]["test_schema"]
assert schema.name == "test_schema"
assert schema.description == "Test schema"
assert len(schema.fields) == 2
@ -308,10 +308,10 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert - bad schema should be ignored
assert "bad_schema" not in processor.schemas
assert "bad_schema" not in processor.schemas.get("default", {})
def test_processor_initialization(self, mock_pulsar_client):
"""Test processor initialization with correct specifications"""

View file

@ -101,7 +101,7 @@ def service(mock_schemas):
taskgroup=MagicMock(),
id="test-processor"
)
service.schemas = mock_schemas
service.schemas = {"default": dict(mock_schemas)}
return service
@ -109,6 +109,7 @@ def service(mock_schemas):
def mock_flow():
"""Create mock flow with prompt service"""
flow = MagicMock()
flow.workspace = "default"
prompt_request_flow = AsyncMock()
flow.return_value.request = prompt_request_flow
return flow, prompt_request_flow

View file

@ -44,7 +44,6 @@ class TestStructuredQueryProcessor:
# Arrange
request = StructuredQueryRequest(
question="Show me all customers from New York",
user="trustgraph",
collection="default"
)
@ -110,7 +109,6 @@ class TestStructuredQueryProcessor:
assert isinstance(objects_call_args, RowsQueryRequest)
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
assert objects_call_args.variables == {"state": "NY"}
assert objects_call_args.user == "trustgraph"
assert objects_call_args.collection == "default"
# Verify response