diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index 38161292..f4f59444 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -59,8 +59,10 @@ class TestAgentStructuredQueryIntegration: # Create agent request request = AgentRequest( question="I need to find all customers from New York. Use the structured query tool to get this information.", - user="test_user", - collection="test_collection" + state="", + group=None, + history=[], + user="test_user" ) msg = MagicMock() @@ -118,7 +120,8 @@ Args: { # Verify structured query was called mock_structured_client.structured_query.assert_called_once() call_args = mock_structured_client.structured_query.call_args - question_arg = call_args[0][0] # positional argument + # Check keyword arguments + question_arg = call_args.kwargs.get("question") or call_args[1].get("question") assert "customers" in question_arg.lower() assert "new york" in question_arg.lower() @@ -140,8 +143,10 @@ Args: { request = AgentRequest( question="Find data from a table that doesn't exist using structured query.", - user="test_user", - collection="test_collection" + state="", + group=None, + history=[], + user="test_user" ) msg = MagicMock() @@ -198,8 +203,9 @@ Args: { # Agent should handle the error gracefully assert any(isinstance(resp, AgentResponse) for resp in responses) # The tool should have returned an error response that contains error info - structured_query_call_args = mock_structured_client.structured_query.call_args[0] - assert "table" in structured_query_call_args[0].lower() or "exist" in structured_query_call_args[0].lower() + call_args = mock_structured_client.structured_query.call_args + question_arg = call_args.kwargs.get("question") or call_args[1].get("question") + assert "table" in question_arg.lower() or "exist" in question_arg.lower() @pytest.mark.asyncio async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config): @@ -209,8 +215,10 @@ Args: { request = AgentRequest( question="First find all customers from California, then tell me how many orders they have made.", - user="test_user", - collection="test_collection" + state="", + group=None, + history=[], + user="test_user" ) msg = MagicMock() @@ -272,8 +280,9 @@ Args: { assert any(isinstance(resp, AgentResponse) for resp in responses) # Verify the structured query was called with customer-related question - call_args = mock_structured_client.structured_query.call_args[0] - assert "california" in call_args[0].lower() + call_args = mock_structured_client.structured_query.call_args + question_arg = call_args.kwargs.get("question") or call_args[1].get("question") + assert "california" in question_arg.lower() @pytest.mark.asyncio async def test_agent_structured_query_with_collection_parameter(self, agent_processor): @@ -295,8 +304,10 @@ Args: { request = AgentRequest( question="Query the sales data for recent transactions.", - user="test_user", - collection="test_collection" + state="", + group=None, + history=[], + user="test_user" ) msg = MagicMock() @@ -359,8 +370,9 @@ Args: { assert any(isinstance(resp, AgentResponse) for resp in responses) # Check the query was about sales/transactions - call_args = mock_structured_client.structured_query.call_args[0] - assert "sales" in call_args[0].lower() or "transactions" in call_args[0].lower() + call_args = mock_structured_client.structured_query.call_args + question_arg = call_args.kwargs.get("question") or call_args[1].get("question") + assert "sales" in question_arg.lower() or "transactions" in question_arg.lower() @pytest.mark.asyncio async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config): @@ -390,8 +402,10 @@ Args: { request = AgentRequest( question="Get customer information and format it nicely.", - user="test_user", - collection="test_collection" + state="", + group=None, + history=[], + user="test_user" ) msg = MagicMock() @@ -463,5 +477,6 @@ Args: { assert any(isinstance(resp, AgentResponse) for resp in responses) # Check that the query was about customer information - call_args = mock_structured_client.structured_query.call_args[0] - assert "customer" in call_args[0].lower() \ No newline at end of file + call_args = mock_structured_client.structured_query.call_args + question_arg = call_args.kwargs.get("question") or call_args[1].get("question") + assert "customer" in question_arg.lower() \ No newline at end of file diff --git a/tests/integration/test_structured_query_integration.py b/tests/integration/test_structured_query_integration.py index 2e836b33..cf8037d0 100644 --- a/tests/integration/test_structured_query_integration.py +++ b/tests/integration/test_structured_query_integration.py @@ -41,7 +41,9 @@ class TestStructuredQueryServiceIntegration: """Test complete structured query processing pipeline""" # Arrange - Create realistic query request request = StructuredQueryRequest( - question="Show me all customers from California who have made purchases over $500" + question="Show me all customers from California who have made purchases over $500", + user="trustgraph", + collection="default" ) msg = MagicMock() diff --git a/tests/unit/test_retrieval/test_structured_query.py b/tests/unit/test_retrieval/test_structured_query.py index 1e78fa97..27c09ca4 100644 --- a/tests/unit/test_retrieval/test_structured_query.py +++ b/tests/unit/test_retrieval/test_structured_query.py @@ -44,7 +44,9 @@ class TestStructuredQueryProcessor: """Test successful end-to-end query processing""" # Arrange request = StructuredQueryRequest( - question="Show me all customers from New York" + question="Show me all customers from New York", + user="trustgraph", + collection="default" ) msg = MagicMock() diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index b3c86c4e..681696c3 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -132,11 +132,15 @@ class FlowInstance: input )["response"] - def agent(self, question): + def agent(self, question, user="trustgraph", state="", group=None, history=None): - # The input consists of a question + # The input consists of a question and optional context input = { - "question": question + "question": question, + "user": user, + "state": state, + "group": group or [], + "history": history or [] } return self.request( @@ -456,20 +460,24 @@ class FlowInstance: return response - def structured_query(self, question): + def structured_query(self, question, user="trustgraph", collection="default"): """ Execute a natural language question against structured data. Combines NLP query conversion and GraphQL execution. Args: question: Natural language question + user: Cassandra keyspace identifier (default: "trustgraph") + collection: Data collection identifier (default: "default") Returns: dict with data and optional errors """ input = { - "question": question + "question": question, + "user": user, + "collection": collection } response = self.request( diff --git a/trustgraph-base/trustgraph/base/structured_query_client.py b/trustgraph-base/trustgraph/base/structured_query_client.py index dc025c4a..84d6bff3 100644 --- a/trustgraph-base/trustgraph/base/structured_query_client.py +++ b/trustgraph-base/trustgraph/base/structured_query_client.py @@ -2,10 +2,12 @@ from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import StructuredQueryRequest, StructuredQueryResponse class StructuredQueryClient(RequestResponse): - async def structured_query(self, question, timeout=600): + async def structured_query(self, question, user="trustgraph", collection="default", timeout=600): resp = await self.request( StructuredQueryRequest( - question = question + question = question, + user = user, + collection = collection ), timeout=timeout ) diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 5529a1a2..4408fea3 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -9,17 +9,19 @@ class AgentRequestTranslator(MessageTranslator): def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest: return AgentRequest( question=data["question"], - plan=data.get("plan", ""), state=data.get("state", ""), - history=data.get("history", []) + group=data.get("group", []), + history=data.get("history", []), + user=data.get("user", "trustgraph") ) def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]: return { "question": obj.question, - "plan": obj.plan, "state": obj.state, - "history": obj.history + "group": obj.group, + "history": obj.history, + "user": obj.user } diff --git a/trustgraph-base/trustgraph/messaging/translators/structured_query.py b/trustgraph-base/trustgraph/messaging/translators/structured_query.py index c6a8abc8..cc3ae80c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/structured_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/structured_query.py @@ -9,12 +9,16 @@ class StructuredQueryRequestTranslator(MessageTranslator): def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest: return StructuredQueryRequest( - question=data.get("question", "") + question=data.get("question", ""), + user=data.get("user", "trustgraph"), # Default fallback + collection=data.get("collection", "default") # Default fallback ) def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]: return { - "question": obj.question + "question": obj.question, + "user": obj.user, + "collection": obj.collection } diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index 55f2ae0f..c9b152b4 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -13,12 +13,14 @@ class AgentStep(Record): action = String() arguments = Map(String()) observation = String() + user = String() # User context for the step class AgentRequest(Record): question = String() state = String() group = Array(String()) history = Array(AgentStep()) + user = String() # User context for multi-tenancy class AgentResponse(Record): answer = String() diff --git a/trustgraph-base/trustgraph/schema/services/structured_query.py b/trustgraph-base/trustgraph/schema/services/structured_query.py index 537fc36b..df21bfe2 100644 --- a/trustgraph-base/trustgraph/schema/services/structured_query.py +++ b/trustgraph-base/trustgraph/schema/services/structured_query.py @@ -9,6 +9,8 @@ from ..core.topic import topic class StructuredQueryRequest(Record): question = String() + user = String() # Cassandra keyspace identifier + collection = String() # Data collection identifier class StructuredQueryResponse(Record): error = Error() diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 4b861919..c5ca93e4 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -61,6 +61,10 @@ async def question( "flow": flow_id, "request": { "question": question, + "user": user, + "state": state or "", + "group": [], + "history": [] } }) diff --git a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py index 8f34e747..9f5f8540 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_structured_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_structured_query.py @@ -79,11 +79,11 @@ def format_table_data(rows, table_name, output_format): else: return json.dumps({table_name: rows}, indent=2) -def structured_query(url, flow_id, question, output_format='table'): +def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'): api = Api(url).flow().id(flow_id) - resp = api.structured_query(question=question) + resp = api.structured_query(question=question, user=user, collection=collection) # Check for errors if "error" in resp and resp["error"]: @@ -132,6 +132,18 @@ def main(): help='Natural language question to execute', ) + parser.add_argument( + '--user', + default='trustgraph', + help='Cassandra keyspace identifier (default: trustgraph)' + ) + + parser.add_argument( + '--collection', + default='default', + help='Data collection identifier (default: default)' + ) + parser.add_argument( '--format', choices=['table', 'json', 'csv'], @@ -147,6 +159,8 @@ def main(): url=args.url, flow_id=args.flow_id, question=args.question, + user=args.user, + collection=args.collection, output_format=args.format, ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 55f8ce45..06bf7610 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -148,7 +148,8 @@ class Processor(AgentService): elif impl_id == "structured-query": impl = functools.partial( StructuredQueryImpl, - collection=data.get("collection") + collection=data.get("collection"), + user=None # User will be provided dynamically via context ) arguments = StructuredQueryImpl.get_arguments() else: @@ -253,12 +254,26 @@ class Processor(AgentService): logger.debug("Call React") + # Create user-aware context wrapper that preserves the flow interface + # but adds user information for tools that need it + class UserAwareContext: + def __init__(self, flow, user): + self._flow = flow + self._user = user + + def __call__(self, service_name): + client = self._flow(service_name) + # For structured query clients, store user context + if service_name == "structured-query-request": + client._current_user = self._user + return client + act = await temp_agent.react( question = request.question, history = history, think = think, observe = observe, - context = flow, + context = UserAwareContext(flow, request.user), ) logger.debug(f"Action: {act}") diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 596df741..e32dc2d8 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -87,9 +87,10 @@ class McpToolImpl: # This tool implementation knows how to query structured data using natural language class StructuredQueryImpl: - def __init__(self, context, collection=None): + def __init__(self, context, collection=None, user=None): self.context = context self.collection = collection # For multi-tenant scenarios + self.user = user # User context for multi-tenancy @staticmethod def get_arguments(): @@ -105,8 +106,13 @@ class StructuredQueryImpl: client = self.context("structured-query-request") logger.debug("Structured query question...") + # Get user from client context if available, otherwise use instance user or default + user = getattr(client, '_current_user', self.user or "trustgraph") + result = await client.structured_query( - arguments.get("question") + question=arguments.get("question"), + user=user, + collection=self.collection or "default" ) # Format the result for the agent diff --git a/trustgraph-flow/trustgraph/agent/tool_filter.py b/trustgraph-flow/trustgraph/agent/tool_filter.py index 0d66b990..d1bac3e4 100644 --- a/trustgraph-flow/trustgraph/agent/tool_filter.py +++ b/trustgraph-flow/trustgraph/agent/tool_filter.py @@ -31,7 +31,7 @@ def filter_tools_by_group_and_state( # Apply defaults as specified in tech spec if requested_groups is None: requested_groups = ["default"] - if current_state is None: + if current_state is None or current_state == "": current_state = "undefined" logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}") diff --git a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py index 28327f82..4b1a04a4 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py @@ -111,11 +111,10 @@ class Processor(FlowProcessor): else: variables_as_strings[key] = str(value) - # Use standard TrustGraph user/collection values - # These should eventually come from authentication/context + # Use user/collection values from request objects_request = ObjectsQueryRequest( - user="trustgraph", # Standard TrustGraph user - collection="default", # Standard default collection + user=request.user, + collection=request.collection, query=nlp_response.graphql_query, variables=variables_as_strings, operation_name=None