Extend use of user + collection fields (#503)

* Collection+user fields in structured query

* User/collection in structured query & agent
This commit is contained in:
cybermaggedon 2025-09-08 18:28:38 +01:00 committed by GitHub
parent a92050c411
commit f22bf13aa6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 122 additions and 45 deletions

View file

@ -59,8 +59,10 @@ class TestAgentStructuredQueryIntegration:
# Create agent request # Create agent request
request = AgentRequest( request = AgentRequest(
question="I need to find all customers from New York. Use the structured query tool to get this information.", question="I need to find all customers from New York. Use the structured query tool to get this information.",
user="test_user", state="",
collection="test_collection" group=None,
history=[],
user="test_user"
) )
msg = MagicMock() msg = MagicMock()
@ -118,7 +120,8 @@ Args: {
# Verify structured query was called # Verify structured query was called
mock_structured_client.structured_query.assert_called_once() mock_structured_client.structured_query.assert_called_once()
call_args = mock_structured_client.structured_query.call_args call_args = mock_structured_client.structured_query.call_args
question_arg = call_args[0][0] # positional argument # Check keyword arguments
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "customers" in question_arg.lower() assert "customers" in question_arg.lower()
assert "new york" in question_arg.lower() assert "new york" in question_arg.lower()
@ -140,8 +143,10 @@ Args: {
request = AgentRequest( request = AgentRequest(
question="Find data from a table that doesn't exist using structured query.", question="Find data from a table that doesn't exist using structured query.",
user="test_user", state="",
collection="test_collection" group=None,
history=[],
user="test_user"
) )
msg = MagicMock() msg = MagicMock()
@ -198,8 +203,9 @@ Args: {
# Agent should handle the error gracefully # Agent should handle the error gracefully
assert any(isinstance(resp, AgentResponse) for resp in responses) assert any(isinstance(resp, AgentResponse) for resp in responses)
# The tool should have returned an error response that contains error info # The tool should have returned an error response that contains error info
structured_query_call_args = mock_structured_client.structured_query.call_args[0] call_args = mock_structured_client.structured_query.call_args
assert "table" in structured_query_call_args[0].lower() or "exist" in structured_query_call_args[0].lower() question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "table" in question_arg.lower() or "exist" in question_arg.lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config): async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
@ -209,8 +215,10 @@ Args: {
request = AgentRequest( request = AgentRequest(
question="First find all customers from California, then tell me how many orders they have made.", question="First find all customers from California, then tell me how many orders they have made.",
user="test_user", state="",
collection="test_collection" group=None,
history=[],
user="test_user"
) )
msg = MagicMock() msg = MagicMock()
@ -272,8 +280,9 @@ Args: {
assert any(isinstance(resp, AgentResponse) for resp in responses) assert any(isinstance(resp, AgentResponse) for resp in responses)
# Verify the structured query was called with customer-related question # Verify the structured query was called with customer-related question
call_args = mock_structured_client.structured_query.call_args[0] call_args = mock_structured_client.structured_query.call_args
assert "california" in call_args[0].lower() question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "california" in question_arg.lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_structured_query_with_collection_parameter(self, agent_processor): async def test_agent_structured_query_with_collection_parameter(self, agent_processor):
@ -295,8 +304,10 @@ Args: {
request = AgentRequest( request = AgentRequest(
question="Query the sales data for recent transactions.", question="Query the sales data for recent transactions.",
user="test_user", state="",
collection="test_collection" group=None,
history=[],
user="test_user"
) )
msg = MagicMock() msg = MagicMock()
@ -359,8 +370,9 @@ Args: {
assert any(isinstance(resp, AgentResponse) for resp in responses) assert any(isinstance(resp, AgentResponse) for resp in responses)
# Check the query was about sales/transactions # Check the query was about sales/transactions
call_args = mock_structured_client.structured_query.call_args[0] call_args = mock_structured_client.structured_query.call_args
assert "sales" in call_args[0].lower() or "transactions" in call_args[0].lower() question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "sales" in question_arg.lower() or "transactions" in question_arg.lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config): async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
@ -390,8 +402,10 @@ Args: {
request = AgentRequest( request = AgentRequest(
question="Get customer information and format it nicely.", question="Get customer information and format it nicely.",
user="test_user", state="",
collection="test_collection" group=None,
history=[],
user="test_user"
) )
msg = MagicMock() msg = MagicMock()
@ -463,5 +477,6 @@ Args: {
assert any(isinstance(resp, AgentResponse) for resp in responses) assert any(isinstance(resp, AgentResponse) for resp in responses)
# Check that the query was about customer information # Check that the query was about customer information
call_args = mock_structured_client.structured_query.call_args[0] call_args = mock_structured_client.structured_query.call_args
assert "customer" in call_args[0].lower() question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "customer" in question_arg.lower()

View file

@ -41,7 +41,9 @@ class TestStructuredQueryServiceIntegration:
"""Test complete structured query processing pipeline""" """Test complete structured query processing pipeline"""
# Arrange - Create realistic query request # Arrange - Create realistic query request
request = StructuredQueryRequest( request = StructuredQueryRequest(
question="Show me all customers from California who have made purchases over $500" question="Show me all customers from California who have made purchases over $500",
user="trustgraph",
collection="default"
) )
msg = MagicMock() msg = MagicMock()

View file

@ -44,7 +44,9 @@ class TestStructuredQueryProcessor:
"""Test successful end-to-end query processing""" """Test successful end-to-end query processing"""
# Arrange # Arrange
request = StructuredQueryRequest( request = StructuredQueryRequest(
question="Show me all customers from New York" question="Show me all customers from New York",
user="trustgraph",
collection="default"
) )
msg = MagicMock() msg = MagicMock()

View file

@ -132,11 +132,15 @@ class FlowInstance:
input input
)["response"] )["response"]
def agent(self, question): def agent(self, question, user="trustgraph", state="", group=None, history=None):
# The input consists of a question # The input consists of a question and optional context
input = { input = {
"question": question "question": question,
"user": user,
"state": state,
"group": group or [],
"history": history or []
} }
return self.request( return self.request(
@ -456,20 +460,24 @@ class FlowInstance:
return response return response
def structured_query(self, question): def structured_query(self, question, user="trustgraph", collection="default"):
""" """
Execute a natural language question against structured data. Execute a natural language question against structured data.
Combines NLP query conversion and GraphQL execution. Combines NLP query conversion and GraphQL execution.
Args: Args:
question: Natural language question question: Natural language question
user: Cassandra keyspace identifier (default: "trustgraph")
collection: Data collection identifier (default: "default")
Returns: Returns:
dict with data and optional errors dict with data and optional errors
""" """
input = { input = {
"question": question "question": question,
"user": user,
"collection": collection
} }
response = self.request( response = self.request(

View file

@ -2,10 +2,12 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import StructuredQueryRequest, StructuredQueryResponse from .. schema import StructuredQueryRequest, StructuredQueryResponse
class StructuredQueryClient(RequestResponse): class StructuredQueryClient(RequestResponse):
async def structured_query(self, question, timeout=600): async def structured_query(self, question, user="trustgraph", collection="default", timeout=600):
resp = await self.request( resp = await self.request(
StructuredQueryRequest( StructuredQueryRequest(
question = question question = question,
user = user,
collection = collection
), ),
timeout=timeout timeout=timeout
) )

View file

@ -9,17 +9,19 @@ class AgentRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest: def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
return AgentRequest( return AgentRequest(
question=data["question"], question=data["question"],
plan=data.get("plan", ""),
state=data.get("state", ""), state=data.get("state", ""),
history=data.get("history", []) group=data.get("group", []),
history=data.get("history", []),
user=data.get("user", "trustgraph")
) )
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]: def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
return { return {
"question": obj.question, "question": obj.question,
"plan": obj.plan,
"state": obj.state, "state": obj.state,
"history": obj.history "group": obj.group,
"history": obj.history,
"user": obj.user
} }

View file

@ -9,12 +9,16 @@ class StructuredQueryRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest: def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest:
return StructuredQueryRequest( return StructuredQueryRequest(
question=data.get("question", "") question=data.get("question", ""),
user=data.get("user", "trustgraph"), # Default fallback
collection=data.get("collection", "default") # Default fallback
) )
def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]: def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]:
return { return {
"question": obj.question "question": obj.question,
"user": obj.user,
"collection": obj.collection
} }

View file

@ -13,12 +13,14 @@ class AgentStep(Record):
action = String() action = String()
arguments = Map(String()) arguments = Map(String())
observation = String() observation = String()
user = String() # User context for the step
class AgentRequest(Record): class AgentRequest(Record):
question = String() question = String()
state = String() state = String()
group = Array(String()) group = Array(String())
history = Array(AgentStep()) history = Array(AgentStep())
user = String() # User context for multi-tenancy
class AgentResponse(Record): class AgentResponse(Record):
answer = String() answer = String()

View file

@ -9,6 +9,8 @@ from ..core.topic import topic
class StructuredQueryRequest(Record): class StructuredQueryRequest(Record):
question = String() question = String()
user = String() # Cassandra keyspace identifier
collection = String() # Data collection identifier
class StructuredQueryResponse(Record): class StructuredQueryResponse(Record):
error = Error() error = Error()

View file

@ -61,6 +61,10 @@ async def question(
"flow": flow_id, "flow": flow_id,
"request": { "request": {
"question": question, "question": question,
"user": user,
"state": state or "",
"group": [],
"history": []
} }
}) })

View file

@ -79,11 +79,11 @@ def format_table_data(rows, table_name, output_format):
else: else:
return json.dumps({table_name: rows}, indent=2) return json.dumps({table_name: rows}, indent=2)
def structured_query(url, flow_id, question, output_format='table'): def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'):
api = Api(url).flow().id(flow_id) api = Api(url).flow().id(flow_id)
resp = api.structured_query(question=question) resp = api.structured_query(question=question, user=user, collection=collection)
# Check for errors # Check for errors
if "error" in resp and resp["error"]: if "error" in resp and resp["error"]:
@ -132,6 +132,18 @@ def main():
help='Natural language question to execute', help='Natural language question to execute',
) )
parser.add_argument(
'--user',
default='trustgraph',
help='Cassandra keyspace identifier (default: trustgraph)'
)
parser.add_argument(
'--collection',
default='default',
help='Data collection identifier (default: default)'
)
parser.add_argument( parser.add_argument(
'--format', '--format',
choices=['table', 'json', 'csv'], choices=['table', 'json', 'csv'],
@ -147,6 +159,8 @@ def main():
url=args.url, url=args.url,
flow_id=args.flow_id, flow_id=args.flow_id,
question=args.question, question=args.question,
user=args.user,
collection=args.collection,
output_format=args.format, output_format=args.format,
) )

View file

@ -148,7 +148,8 @@ class Processor(AgentService):
elif impl_id == "structured-query": elif impl_id == "structured-query":
impl = functools.partial( impl = functools.partial(
StructuredQueryImpl, StructuredQueryImpl,
collection=data.get("collection") collection=data.get("collection"),
user=None # User will be provided dynamically via context
) )
arguments = StructuredQueryImpl.get_arguments() arguments = StructuredQueryImpl.get_arguments()
else: else:
@ -253,12 +254,26 @@ class Processor(AgentService):
logger.debug("Call React") logger.debug("Call React")
# Create user-aware context wrapper that preserves the flow interface
# but adds user information for tools that need it
class UserAwareContext:
def __init__(self, flow, user):
self._flow = flow
self._user = user
def __call__(self, service_name):
client = self._flow(service_name)
# For structured query clients, store user context
if service_name == "structured-query-request":
client._current_user = self._user
return client
act = await temp_agent.react( act = await temp_agent.react(
question = request.question, question = request.question,
history = history, history = history,
think = think, think = think,
observe = observe, observe = observe,
context = flow, context = UserAwareContext(flow, request.user),
) )
logger.debug(f"Action: {act}") logger.debug(f"Action: {act}")

View file

@ -87,9 +87,10 @@ class McpToolImpl:
# This tool implementation knows how to query structured data using natural language # This tool implementation knows how to query structured data using natural language
class StructuredQueryImpl: class StructuredQueryImpl:
def __init__(self, context, collection=None): def __init__(self, context, collection=None, user=None):
self.context = context self.context = context
self.collection = collection # For multi-tenant scenarios self.collection = collection # For multi-tenant scenarios
self.user = user # User context for multi-tenancy
@staticmethod @staticmethod
def get_arguments(): def get_arguments():
@ -105,8 +106,13 @@ class StructuredQueryImpl:
client = self.context("structured-query-request") client = self.context("structured-query-request")
logger.debug("Structured query question...") logger.debug("Structured query question...")
# Get user from client context if available, otherwise use instance user or default
user = getattr(client, '_current_user', self.user or "trustgraph")
result = await client.structured_query( result = await client.structured_query(
arguments.get("question") question=arguments.get("question"),
user=user,
collection=self.collection or "default"
) )
# Format the result for the agent # Format the result for the agent

View file

@ -31,7 +31,7 @@ def filter_tools_by_group_and_state(
# Apply defaults as specified in tech spec # Apply defaults as specified in tech spec
if requested_groups is None: if requested_groups is None:
requested_groups = ["default"] requested_groups = ["default"]
if current_state is None: if current_state is None or current_state == "":
current_state = "undefined" current_state = "undefined"
logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}") logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")

View file

@ -111,11 +111,10 @@ class Processor(FlowProcessor):
else: else:
variables_as_strings[key] = str(value) variables_as_strings[key] = str(value)
# Use standard TrustGraph user/collection values # Use user/collection values from request
# These should eventually come from authentication/context
objects_request = ObjectsQueryRequest( objects_request = ObjectsQueryRequest(
user="trustgraph", # Standard TrustGraph user user=request.user,
collection="default", # Standard default collection collection=request.collection,
query=nlp_response.graphql_query, query=nlp_response.graphql_query,
variables=variables_as_strings, variables=variables_as_strings,
operation_name=None operation_name=None