mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 01:16:22 +02:00
Extend use of user + collection fields (#503)
* Collection+user fields in structured query * User/collection in structured query & agent
This commit is contained in:
parent
a92050c411
commit
f22bf13aa6
15 changed files with 122 additions and 45 deletions
|
|
@ -59,8 +59,10 @@ class TestAgentStructuredQueryIntegration:
|
|||
# Create agent request
|
||||
request = AgentRequest(
|
||||
question="I need to find all customers from New York. Use the structured query tool to get this information.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -118,7 +120,8 @@ Args: {
|
|||
# Verify structured query was called
|
||||
mock_structured_client.structured_query.assert_called_once()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args[0][0] # positional argument
|
||||
# Check keyword arguments
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "customers" in question_arg.lower()
|
||||
assert "new york" in question_arg.lower()
|
||||
|
||||
|
|
@ -140,8 +143,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="Find data from a table that doesn't exist using structured query.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -198,8 +203,9 @@ Args: {
|
|||
# Agent should handle the error gracefully
|
||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
# The tool should have returned an error response that contains error info
|
||||
structured_query_call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "table" in structured_query_call_args[0].lower() or "exist" in structured_query_call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "table" in question_arg.lower() or "exist" in question_arg.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
||||
|
|
@ -209,8 +215,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="First find all customers from California, then tell me how many orders they have made.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -272,8 +280,9 @@ Args: {
|
|||
|
||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
# Verify the structured query was called with customer-related question
|
||||
call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "california" in call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "california" in question_arg.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_structured_query_with_collection_parameter(self, agent_processor):
|
||||
|
|
@ -295,8 +304,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="Query the sales data for recent transactions.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -359,8 +370,9 @@ Args: {
|
|||
|
||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
# Check the query was about sales/transactions
|
||||
call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "sales" in call_args[0].lower() or "transactions" in call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "sales" in question_arg.lower() or "transactions" in question_arg.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
||||
|
|
@ -390,8 +402,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="Get customer information and format it nicely.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -463,5 +477,6 @@ Args: {
|
|||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
|
||||
# Check that the query was about customer information
|
||||
call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "customer" in call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "customer" in question_arg.lower()
|
||||
Loading…
Add table
Add a link
Reference in a new issue