Extend use of user + collection fields (#503)

* Collection+user fields in structured query

* User/collection in structured query & agent
This commit is contained in:
cybermaggedon 2025-09-08 18:28:38 +01:00 committed by GitHub
parent a92050c411
commit f22bf13aa6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 122 additions and 45 deletions

View file

@ -59,8 +59,10 @@ class TestAgentStructuredQueryIntegration:
# Create agent request
request = AgentRequest(
question="I need to find all customers from New York. Use the structured query tool to get this information.",
user="test_user",
collection="test_collection"
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -118,7 +120,8 @@ Args: {
# Verify structured query was called
mock_structured_client.structured_query.assert_called_once()
call_args = mock_structured_client.structured_query.call_args
question_arg = call_args[0][0] # positional argument
# Check keyword arguments
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "customers" in question_arg.lower()
assert "new york" in question_arg.lower()
@ -140,8 +143,10 @@ Args: {
request = AgentRequest(
question="Find data from a table that doesn't exist using structured query.",
user="test_user",
collection="test_collection"
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -198,8 +203,9 @@ Args: {
# Agent should handle the error gracefully
assert any(isinstance(resp, AgentResponse) for resp in responses)
# The tool should have returned an error response that contains error info
structured_query_call_args = mock_structured_client.structured_query.call_args[0]
assert "table" in structured_query_call_args[0].lower() or "exist" in structured_query_call_args[0].lower()
call_args = mock_structured_client.structured_query.call_args
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "table" in question_arg.lower() or "exist" in question_arg.lower()
@pytest.mark.asyncio
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
@ -209,8 +215,10 @@ Args: {
request = AgentRequest(
question="First find all customers from California, then tell me how many orders they have made.",
user="test_user",
collection="test_collection"
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -272,8 +280,9 @@ Args: {
assert any(isinstance(resp, AgentResponse) for resp in responses)
# Verify the structured query was called with customer-related question
call_args = mock_structured_client.structured_query.call_args[0]
assert "california" in call_args[0].lower()
call_args = mock_structured_client.structured_query.call_args
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "california" in question_arg.lower()
@pytest.mark.asyncio
async def test_agent_structured_query_with_collection_parameter(self, agent_processor):
@ -295,8 +304,10 @@ Args: {
request = AgentRequest(
question="Query the sales data for recent transactions.",
user="test_user",
collection="test_collection"
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -359,8 +370,9 @@ Args: {
assert any(isinstance(resp, AgentResponse) for resp in responses)
# Check the query was about sales/transactions
call_args = mock_structured_client.structured_query.call_args[0]
assert "sales" in call_args[0].lower() or "transactions" in call_args[0].lower()
call_args = mock_structured_client.structured_query.call_args
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "sales" in question_arg.lower() or "transactions" in question_arg.lower()
@pytest.mark.asyncio
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
@ -390,8 +402,10 @@ Args: {
request = AgentRequest(
question="Get customer information and format it nicely.",
user="test_user",
collection="test_collection"
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -463,5 +477,6 @@ Args: {
assert any(isinstance(resp, AgentResponse) for resp in responses)
# Check that the query was about customer information
call_args = mock_structured_client.structured_query.call_args[0]
assert "customer" in call_args[0].lower()
call_args = mock_structured_client.structured_query.call_args
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
assert "customer" in question_arg.lower()

View file

@ -41,7 +41,9 @@ class TestStructuredQueryServiceIntegration:
"""Test complete structured query processing pipeline"""
# Arrange - Create realistic query request
request = StructuredQueryRequest(
question="Show me all customers from California who have made purchases over $500"
question="Show me all customers from California who have made purchases over $500",
user="trustgraph",
collection="default"
)
msg = MagicMock()

View file

@ -44,7 +44,9 @@ class TestStructuredQueryProcessor:
"""Test successful end-to-end query processing"""
# Arrange
request = StructuredQueryRequest(
question="Show me all customers from New York"
question="Show me all customers from New York",
user="trustgraph",
collection="default"
)
msg = MagicMock()

View file

@ -132,11 +132,15 @@ class FlowInstance:
input
)["response"]
def agent(self, question):
def agent(self, question, user="trustgraph", state="", group=None, history=None):
# The input consists of a question
# The input consists of a question and optional context
input = {
"question": question
"question": question,
"user": user,
"state": state,
"group": group or [],
"history": history or []
}
return self.request(
@ -456,20 +460,24 @@ class FlowInstance:
return response
def structured_query(self, question):
def structured_query(self, question, user="trustgraph", collection="default"):
"""
Execute a natural language question against structured data.
Combines NLP query conversion and GraphQL execution.
Args:
question: Natural language question
user: Cassandra keyspace identifier (default: "trustgraph")
collection: Data collection identifier (default: "default")
Returns:
dict with data and optional errors
"""
input = {
"question": question
"question": question,
"user": user,
"collection": collection
}
response = self.request(

View file

@ -2,10 +2,12 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import StructuredQueryRequest, StructuredQueryResponse
class StructuredQueryClient(RequestResponse):
async def structured_query(self, question, timeout=600):
async def structured_query(self, question, user="trustgraph", collection="default", timeout=600):
resp = await self.request(
StructuredQueryRequest(
question = question
question = question,
user = user,
collection = collection
),
timeout=timeout
)

View file

@ -9,17 +9,19 @@ class AgentRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
return AgentRequest(
question=data["question"],
plan=data.get("plan", ""),
state=data.get("state", ""),
history=data.get("history", [])
group=data.get("group", []),
history=data.get("history", []),
user=data.get("user", "trustgraph")
)
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
return {
"question": obj.question,
"plan": obj.plan,
"state": obj.state,
"history": obj.history
"group": obj.group,
"history": obj.history,
"user": obj.user
}

View file

@ -9,12 +9,16 @@ class StructuredQueryRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest:
return StructuredQueryRequest(
question=data.get("question", "")
question=data.get("question", ""),
user=data.get("user", "trustgraph"), # Default fallback
collection=data.get("collection", "default") # Default fallback
)
def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]:
return {
"question": obj.question
"question": obj.question,
"user": obj.user,
"collection": obj.collection
}

View file

@ -13,12 +13,14 @@ class AgentStep(Record):
action = String()
arguments = Map(String())
observation = String()
user = String() # User context for the step
# Schema record carrying one request to the agent service.
# Field declaration order is part of the record definition; do not reorder.
class AgentRequest(Record):
# The question the agent is asked to answer.
question = String()
# Current agent state; callers in this change send "" when none is set.
# NOTE(review): presumably consumed by tool filtering by state -- confirm.
state = String()
# List of group names supplied with the request.
# NOTE(review): presumably the tool groups used for tool filtering -- confirm.
group = Array(String())
# Earlier agent steps (AgentStep records) passed back in with the request.
history = Array(AgentStep())
user = String() # User context for multi-tenancy
class AgentResponse(Record):
answer = String()

View file

@ -9,6 +9,8 @@ from ..core.topic import topic
# Schema record for a natural-language question to be run against structured data.
# Callers elsewhere in this change default user to "trustgraph" and collection
# to "default" when the request does not supply them.
class StructuredQueryRequest(Record):
# The natural-language question text.
question = String()
user = String() # Cassandra keyspace identifier
collection = String() # Data collection identifier
class StructuredQueryResponse(Record):
error = Error()

View file

@ -61,6 +61,10 @@ async def question(
"flow": flow_id,
"request": {
"question": question,
"user": user,
"state": state or "",
"group": [],
"history": []
}
})

View file

@ -79,11 +79,11 @@ def format_table_data(rows, table_name, output_format):
else:
return json.dumps({table_name: rows}, indent=2)
def structured_query(url, flow_id, question, output_format='table'):
def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'):
api = Api(url).flow().id(flow_id)
resp = api.structured_query(question=question)
resp = api.structured_query(question=question, user=user, collection=collection)
# Check for errors
if "error" in resp and resp["error"]:
@ -132,6 +132,18 @@ def main():
help='Natural language question to execute',
)
parser.add_argument(
'--user',
default='trustgraph',
help='Cassandra keyspace identifier (default: trustgraph)'
)
parser.add_argument(
'--collection',
default='default',
help='Data collection identifier (default: default)'
)
parser.add_argument(
'--format',
choices=['table', 'json', 'csv'],
@ -147,6 +159,8 @@ def main():
url=args.url,
flow_id=args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
output_format=args.format,
)

View file

@ -148,7 +148,8 @@ class Processor(AgentService):
elif impl_id == "structured-query":
impl = functools.partial(
StructuredQueryImpl,
collection=data.get("collection")
collection=data.get("collection"),
user=None # User will be provided dynamically via context
)
arguments = StructuredQueryImpl.get_arguments()
else:
@ -253,12 +254,26 @@ class Processor(AgentService):
logger.debug("Call React")
# Create user-aware context wrapper that preserves the flow interface
# but adds user information for tools that need it
class UserAwareContext:
    """Callable wrapper around ``flow`` that tags structured-query clients with a user.

    Calling an instance with a service name delegates to the wrapped ``flow``
    callable and returns whatever client it produces.  When the requested
    service is ``"structured-query-request"`` the wrapped user is stored on
    that client as ``_current_user`` before it is returned; clients for any
    other service are returned untouched.
    """

    def __init__(self, flow, user):
        # flow: callable mapping a service name to a client object.
        # user: value to attach to structured-query clients.
        self._flow = flow
        self._user = user

    def __call__(self, service_name):
        svc_client = self._flow(service_name)
        # Only the structured-query client gets the user attached.
        if service_name != "structured-query-request":
            return svc_client
        svc_client._current_user = self._user
        return svc_client
act = await temp_agent.react(
question = request.question,
history = history,
think = think,
observe = observe,
context = flow,
context = UserAwareContext(flow, request.user),
)
logger.debug(f"Action: {act}")

View file

@ -87,9 +87,10 @@ class McpToolImpl:
# This tool implementation knows how to query structured data using natural language
class StructuredQueryImpl:
def __init__(self, context, collection=None):
def __init__(self, context, collection=None, user=None):
self.context = context
self.collection = collection # For multi-tenant scenarios
self.user = user # User context for multi-tenancy
@staticmethod
def get_arguments():
@ -105,8 +106,13 @@ class StructuredQueryImpl:
client = self.context("structured-query-request")
logger.debug("Structured query question...")
# Get user from client context if available, otherwise use instance user or default
user = getattr(client, '_current_user', self.user or "trustgraph")
result = await client.structured_query(
arguments.get("question")
question=arguments.get("question"),
user=user,
collection=self.collection or "default"
)
# Format the result for the agent

View file

@ -31,7 +31,7 @@ def filter_tools_by_group_and_state(
# Apply defaults as specified in tech spec
if requested_groups is None:
requested_groups = ["default"]
if current_state is None:
if current_state is None or current_state == "":
current_state = "undefined"
logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")

View file

@ -111,11 +111,10 @@ class Processor(FlowProcessor):
else:
variables_as_strings[key] = str(value)
# Use standard TrustGraph user/collection values
# These should eventually come from authentication/context
# Use user/collection values from request
objects_request = ObjectsQueryRequest(
user="trustgraph", # Standard TrustGraph user
collection="default", # Standard default collection
user=request.user,
collection=request.collection,
query=nlp_response.graphql_query,
variables=variables_as_strings,
operation_name=None