mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-05 13:22:37 +02:00
Fix/document rag (#506)
* Fix missing document RAG user/collection params * Added test
This commit is contained in:
parent
266454e75f
commit
6ac8a7c2d9
2 changed files with 83 additions and 1 deletions
77
tests/unit/test_retrieval/test_document_rag_service.py
Normal file
77
tests/unit/test_retrieval/test_document_rag_service.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""
|
||||||
|
Unit test for DocumentRAG service parameter passing fix.
|
||||||
|
Tests that user and collection parameters from the message are correctly
|
||||||
|
passed to the DocumentRag.query() method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
from trustgraph.retrieval.document_rag.rag import Processor
|
||||||
|
from trustgraph.schema import DocumentRagQuery, DocumentRagResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentRagService:
|
||||||
|
"""Test DocumentRAG service parameter passing"""
|
||||||
|
|
||||||
|
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
|
||||||
|
"""
|
||||||
|
Test that user and collection from message are passed to DocumentRag.query().
|
||||||
|
|
||||||
|
This is a regression test for the bug where user/collection parameters
|
||||||
|
were ignored, causing wrong collection names like 'd_trustgraph_default_384'
|
||||||
|
instead of 'd_my_user_test_coll_1_384'.
|
||||||
|
"""
|
||||||
|
# Setup processor
|
||||||
|
processor = Processor(
|
||||||
|
taskgroup=MagicMock(),
|
||||||
|
id="test-processor",
|
||||||
|
doc_limit=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup mock DocumentRag instance
|
||||||
|
mock_rag_instance = AsyncMock()
|
||||||
|
mock_document_rag_class.return_value = mock_rag_instance
|
||||||
|
mock_rag_instance.query.return_value = "test response"
|
||||||
|
|
||||||
|
# Setup message with custom user/collection
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = DocumentRagQuery(
|
||||||
|
query="test query",
|
||||||
|
user="my_user", # Custom user (not default "trustgraph")
|
||||||
|
collection="test_coll_1", # Custom collection (not default "default")
|
||||||
|
doc_limit=5
|
||||||
|
)
|
||||||
|
msg.properties.return_value = {"id": "test-id"}
|
||||||
|
|
||||||
|
# Setup flow mock
|
||||||
|
consumer = MagicMock()
|
||||||
|
flow = MagicMock()
|
||||||
|
|
||||||
|
# Mock flow to return AsyncMock for clients and response producer
|
||||||
|
mock_producer = AsyncMock()
|
||||||
|
def flow_router(service_name):
|
||||||
|
if service_name == "response":
|
||||||
|
return mock_producer
|
||||||
|
return AsyncMock() # embeddings, doc-embeddings, prompt clients
|
||||||
|
flow.side_effect = flow_router
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
await processor.on_request(msg, consumer, flow)
|
||||||
|
|
||||||
|
# Verify: DocumentRag.query was called with correct parameters
|
||||||
|
mock_rag_instance.query.assert_called_once_with(
|
||||||
|
"test query",
|
||||||
|
user="my_user", # Must be from message, not hardcoded default
|
||||||
|
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||||
|
doc_limit=5
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response was sent
|
||||||
|
mock_producer.send.assert_called_once()
|
||||||
|
sent_response = mock_producer.send.call_args[0][0]
|
||||||
|
assert isinstance(sent_response, DocumentRagResponse)
|
||||||
|
assert sent_response.response == "test response"
|
||||||
|
assert sent_response.error is None
|
||||||
|
|
@ -92,7 +92,12 @@ class Processor(FlowProcessor):
|
||||||
else:
|
else:
|
||||||
doc_limit = self.doc_limit
|
doc_limit = self.doc_limit
|
||||||
|
|
||||||
response = await self.rag.query(v.query, doc_limit=doc_limit)
|
response = await self.rag.query(
|
||||||
|
v.query,
|
||||||
|
user=v.user,
|
||||||
|
collection=v.collection,
|
||||||
|
doc_limit=doc_limit
|
||||||
|
)
|
||||||
|
|
||||||
await flow("response").send(
|
await flow("response").send(
|
||||||
DocumentRagResponse(
|
DocumentRagResponse(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue