mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-05 11:22:11 +02:00
Extend use of user + collection fields (#503)
* Collection+user fields in structured query
* User/collection in structured query & agent
This commit is contained in:
parent
a92050c411
commit
f22bf13aa6
15 changed files with 122 additions and 45 deletions
|
|
@ -59,8 +59,10 @@ class TestAgentStructuredQueryIntegration:
|
||||||
# Create agent request
|
# Create agent request
|
||||||
request = AgentRequest(
|
request = AgentRequest(
|
||||||
question="I need to find all customers from New York. Use the structured query tool to get this information.",
|
question="I need to find all customers from New York. Use the structured query tool to get this information.",
|
||||||
user="test_user",
|
state="",
|
||||||
collection="test_collection"
|
group=None,
|
||||||
|
history=[],
|
||||||
|
user="test_user"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
|
|
@ -118,7 +120,8 @@ Args: {
|
||||||
# Verify structured query was called
|
# Verify structured query was called
|
||||||
mock_structured_client.structured_query.assert_called_once()
|
mock_structured_client.structured_query.assert_called_once()
|
||||||
call_args = mock_structured_client.structured_query.call_args
|
call_args = mock_structured_client.structured_query.call_args
|
||||||
question_arg = call_args[0][0] # positional argument
|
# Check keyword arguments
|
||||||
|
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||||
assert "customers" in question_arg.lower()
|
assert "customers" in question_arg.lower()
|
||||||
assert "new york" in question_arg.lower()
|
assert "new york" in question_arg.lower()
|
||||||
|
|
||||||
|
|
@ -140,8 +143,10 @@ Args: {
|
||||||
|
|
||||||
request = AgentRequest(
|
request = AgentRequest(
|
||||||
question="Find data from a table that doesn't exist using structured query.",
|
question="Find data from a table that doesn't exist using structured query.",
|
||||||
user="test_user",
|
state="",
|
||||||
collection="test_collection"
|
group=None,
|
||||||
|
history=[],
|
||||||
|
user="test_user"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
|
|
@ -198,8 +203,9 @@ Args: {
|
||||||
# Agent should handle the error gracefully
|
# Agent should handle the error gracefully
|
||||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||||
# The tool should have returned an error response that contains error info
|
# The tool should have returned an error response that contains error info
|
||||||
structured_query_call_args = mock_structured_client.structured_query.call_args[0]
|
call_args = mock_structured_client.structured_query.call_args
|
||||||
assert "table" in structured_query_call_args[0].lower() or "exist" in structured_query_call_args[0].lower()
|
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||||
|
assert "table" in question_arg.lower() or "exist" in question_arg.lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
||||||
|
|
@ -209,8 +215,10 @@ Args: {
|
||||||
|
|
||||||
request = AgentRequest(
|
request = AgentRequest(
|
||||||
question="First find all customers from California, then tell me how many orders they have made.",
|
question="First find all customers from California, then tell me how many orders they have made.",
|
||||||
user="test_user",
|
state="",
|
||||||
collection="test_collection"
|
group=None,
|
||||||
|
history=[],
|
||||||
|
user="test_user"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
|
|
@ -272,8 +280,9 @@ Args: {
|
||||||
|
|
||||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||||
# Verify the structured query was called with customer-related question
|
# Verify the structured query was called with customer-related question
|
||||||
call_args = mock_structured_client.structured_query.call_args[0]
|
call_args = mock_structured_client.structured_query.call_args
|
||||||
assert "california" in call_args[0].lower()
|
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||||
|
assert "california" in question_arg.lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_structured_query_with_collection_parameter(self, agent_processor):
|
async def test_agent_structured_query_with_collection_parameter(self, agent_processor):
|
||||||
|
|
@ -295,8 +304,10 @@ Args: {
|
||||||
|
|
||||||
request = AgentRequest(
|
request = AgentRequest(
|
||||||
question="Query the sales data for recent transactions.",
|
question="Query the sales data for recent transactions.",
|
||||||
user="test_user",
|
state="",
|
||||||
collection="test_collection"
|
group=None,
|
||||||
|
history=[],
|
||||||
|
user="test_user"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
|
|
@ -359,8 +370,9 @@ Args: {
|
||||||
|
|
||||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||||
# Check the query was about sales/transactions
|
# Check the query was about sales/transactions
|
||||||
call_args = mock_structured_client.structured_query.call_args[0]
|
call_args = mock_structured_client.structured_query.call_args
|
||||||
assert "sales" in call_args[0].lower() or "transactions" in call_args[0].lower()
|
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||||
|
assert "sales" in question_arg.lower() or "transactions" in question_arg.lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
||||||
|
|
@ -390,8 +402,10 @@ Args: {
|
||||||
|
|
||||||
request = AgentRequest(
|
request = AgentRequest(
|
||||||
question="Get customer information and format it nicely.",
|
question="Get customer information and format it nicely.",
|
||||||
user="test_user",
|
state="",
|
||||||
collection="test_collection"
|
group=None,
|
||||||
|
history=[],
|
||||||
|
user="test_user"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
|
|
@ -463,5 +477,6 @@ Args: {
|
||||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||||
|
|
||||||
# Check that the query was about customer information
|
# Check that the query was about customer information
|
||||||
call_args = mock_structured_client.structured_query.call_args[0]
|
call_args = mock_structured_client.structured_query.call_args
|
||||||
assert "customer" in call_args[0].lower()
|
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||||
|
assert "customer" in question_arg.lower()
|
||||||
|
|
@ -41,7 +41,9 @@ class TestStructuredQueryServiceIntegration:
|
||||||
"""Test complete structured query processing pipeline"""
|
"""Test complete structured query processing pipeline"""
|
||||||
# Arrange - Create realistic query request
|
# Arrange - Create realistic query request
|
||||||
request = StructuredQueryRequest(
|
request = StructuredQueryRequest(
|
||||||
question="Show me all customers from California who have made purchases over $500"
|
question="Show me all customers from California who have made purchases over $500",
|
||||||
|
user="trustgraph",
|
||||||
|
collection="default"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,9 @@ class TestStructuredQueryProcessor:
|
||||||
"""Test successful end-to-end query processing"""
|
"""Test successful end-to-end query processing"""
|
||||||
# Arrange
|
# Arrange
|
||||||
request = StructuredQueryRequest(
|
request = StructuredQueryRequest(
|
||||||
question="Show me all customers from New York"
|
question="Show me all customers from New York",
|
||||||
|
user="trustgraph",
|
||||||
|
collection="default"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
|
|
|
||||||
|
|
@ -132,11 +132,15 @@ class FlowInstance:
|
||||||
input
|
input
|
||||||
)["response"]
|
)["response"]
|
||||||
|
|
||||||
def agent(self, question):
|
def agent(self, question, user="trustgraph", state="", group=None, history=None):
|
||||||
|
|
||||||
# The input consists of a question
|
# The input consists of a question and optional context
|
||||||
input = {
|
input = {
|
||||||
"question": question
|
"question": question,
|
||||||
|
"user": user,
|
||||||
|
"state": state,
|
||||||
|
"group": group or [],
|
||||||
|
"history": history or []
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.request(
|
return self.request(
|
||||||
|
|
@ -456,20 +460,24 @@ class FlowInstance:
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def structured_query(self, question):
|
def structured_query(self, question, user="trustgraph", collection="default"):
|
||||||
"""
|
"""
|
||||||
Execute a natural language question against structured data.
|
Execute a natural language question against structured data.
|
||||||
Combines NLP query conversion and GraphQL execution.
|
Combines NLP query conversion and GraphQL execution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
question: Natural language question
|
question: Natural language question
|
||||||
|
user: Cassandra keyspace identifier (default: "trustgraph")
|
||||||
|
collection: Data collection identifier (default: "default")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with data and optional errors
|
dict with data and optional errors
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input = {
|
input = {
|
||||||
"question": question
|
"question": question,
|
||||||
|
"user": user,
|
||||||
|
"collection": collection
|
||||||
}
|
}
|
||||||
|
|
||||||
response = self.request(
|
response = self.request(
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,12 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||||
from .. schema import StructuredQueryRequest, StructuredQueryResponse
|
from .. schema import StructuredQueryRequest, StructuredQueryResponse
|
||||||
|
|
||||||
class StructuredQueryClient(RequestResponse):
|
class StructuredQueryClient(RequestResponse):
|
||||||
async def structured_query(self, question, timeout=600):
|
async def structured_query(self, question, user="trustgraph", collection="default", timeout=600):
|
||||||
resp = await self.request(
|
resp = await self.request(
|
||||||
StructuredQueryRequest(
|
StructuredQueryRequest(
|
||||||
question = question
|
question = question,
|
||||||
|
user = user,
|
||||||
|
collection = collection
|
||||||
),
|
),
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,17 +9,19 @@ class AgentRequestTranslator(MessageTranslator):
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
|
||||||
return AgentRequest(
|
return AgentRequest(
|
||||||
question=data["question"],
|
question=data["question"],
|
||||||
plan=data.get("plan", ""),
|
|
||||||
state=data.get("state", ""),
|
state=data.get("state", ""),
|
||||||
history=data.get("history", [])
|
group=data.get("group", []),
|
||||||
|
history=data.get("history", []),
|
||||||
|
user=data.get("user", "trustgraph")
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"question": obj.question,
|
"question": obj.question,
|
||||||
"plan": obj.plan,
|
|
||||||
"state": obj.state,
|
"state": obj.state,
|
||||||
"history": obj.history
|
"group": obj.group,
|
||||||
|
"history": obj.history,
|
||||||
|
"user": obj.user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,16 @@ class StructuredQueryRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest:
|
||||||
return StructuredQueryRequest(
|
return StructuredQueryRequest(
|
||||||
question=data.get("question", "")
|
question=data.get("question", ""),
|
||||||
|
user=data.get("user", "trustgraph"), # Default fallback
|
||||||
|
collection=data.get("collection", "default") # Default fallback
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"question": obj.question
|
"question": obj.question,
|
||||||
|
"user": obj.user,
|
||||||
|
"collection": obj.collection
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,14 @@ class AgentStep(Record):
|
||||||
action = String()
|
action = String()
|
||||||
arguments = Map(String())
|
arguments = Map(String())
|
||||||
observation = String()
|
observation = String()
|
||||||
|
user = String() # User context for the step
|
||||||
|
|
||||||
class AgentRequest(Record):
|
class AgentRequest(Record):
|
||||||
question = String()
|
question = String()
|
||||||
state = String()
|
state = String()
|
||||||
group = Array(String())
|
group = Array(String())
|
||||||
history = Array(AgentStep())
|
history = Array(AgentStep())
|
||||||
|
user = String() # User context for multi-tenancy
|
||||||
|
|
||||||
class AgentResponse(Record):
|
class AgentResponse(Record):
|
||||||
answer = String()
|
answer = String()
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ from ..core.topic import topic
|
||||||
|
|
||||||
class StructuredQueryRequest(Record):
|
class StructuredQueryRequest(Record):
|
||||||
question = String()
|
question = String()
|
||||||
|
user = String() # Cassandra keyspace identifier
|
||||||
|
collection = String() # Data collection identifier
|
||||||
|
|
||||||
class StructuredQueryResponse(Record):
|
class StructuredQueryResponse(Record):
|
||||||
error = Error()
|
error = Error()
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,10 @@ async def question(
|
||||||
"flow": flow_id,
|
"flow": flow_id,
|
||||||
"request": {
|
"request": {
|
||||||
"question": question,
|
"question": question,
|
||||||
|
"user": user,
|
||||||
|
"state": state or "",
|
||||||
|
"group": [],
|
||||||
|
"history": []
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -79,11 +79,11 @@ def format_table_data(rows, table_name, output_format):
|
||||||
else:
|
else:
|
||||||
return json.dumps({table_name: rows}, indent=2)
|
return json.dumps({table_name: rows}, indent=2)
|
||||||
|
|
||||||
def structured_query(url, flow_id, question, output_format='table'):
|
def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'):
|
||||||
|
|
||||||
api = Api(url).flow().id(flow_id)
|
api = Api(url).flow().id(flow_id)
|
||||||
|
|
||||||
resp = api.structured_query(question=question)
|
resp = api.structured_query(question=question, user=user, collection=collection)
|
||||||
|
|
||||||
# Check for errors
|
# Check for errors
|
||||||
if "error" in resp and resp["error"]:
|
if "error" in resp and resp["error"]:
|
||||||
|
|
@ -132,6 +132,18 @@ def main():
|
||||||
help='Natural language question to execute',
|
help='Natural language question to execute',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--user',
|
||||||
|
default='trustgraph',
|
||||||
|
help='Cassandra keyspace identifier (default: trustgraph)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--collection',
|
||||||
|
default='default',
|
||||||
|
help='Data collection identifier (default: default)'
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--format',
|
'--format',
|
||||||
choices=['table', 'json', 'csv'],
|
choices=['table', 'json', 'csv'],
|
||||||
|
|
@ -147,6 +159,8 @@ def main():
|
||||||
url=args.url,
|
url=args.url,
|
||||||
flow_id=args.flow_id,
|
flow_id=args.flow_id,
|
||||||
question=args.question,
|
question=args.question,
|
||||||
|
user=args.user,
|
||||||
|
collection=args.collection,
|
||||||
output_format=args.format,
|
output_format=args.format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -148,7 +148,8 @@ class Processor(AgentService):
|
||||||
elif impl_id == "structured-query":
|
elif impl_id == "structured-query":
|
||||||
impl = functools.partial(
|
impl = functools.partial(
|
||||||
StructuredQueryImpl,
|
StructuredQueryImpl,
|
||||||
collection=data.get("collection")
|
collection=data.get("collection"),
|
||||||
|
user=None # User will be provided dynamically via context
|
||||||
)
|
)
|
||||||
arguments = StructuredQueryImpl.get_arguments()
|
arguments = StructuredQueryImpl.get_arguments()
|
||||||
else:
|
else:
|
||||||
|
|
@ -253,12 +254,26 @@ class Processor(AgentService):
|
||||||
|
|
||||||
logger.debug("Call React")
|
logger.debug("Call React")
|
||||||
|
|
||||||
|
# Create user-aware context wrapper that preserves the flow interface
|
||||||
|
# but adds user information for tools that need it
|
||||||
|
class UserAwareContext:
    """Callable wrapper around a flow that tags structured-query clients with a user.

    Calling the instance with a service name returns exactly what the
    wrapped ``flow`` callable returns for that name, so the wrapper can be
    passed wherever the flow itself is used as a tool context.

    Args:
        flow: Callable mapping a service name to its client object.
        user: User identifier of the current request.  A missing or empty
            value is replaced by ``DEFAULT_USER`` so that a structured-query
            client is never tagged with ``None`` or ``""``.
    """

    # Same value the structured-query tool falls back to when no user is known.
    DEFAULT_USER = "trustgraph"

    def __init__(self, flow, user):
        self._flow = flow
        # Normalise here rather than in the reader: __call__ always sets the
        # attribute, so a reader's ``getattr(..., default)`` never applies.
        self._user = user or self.DEFAULT_USER

    def __call__(self, service_name):
        """Return the client for ``service_name``, tagging it when applicable."""
        client = self._flow(service_name)
        # For structured query clients, store user context
        if service_name == "structured-query-request":
            client._current_user = self._user
        return client
|
||||||
|
|
||||||
act = await temp_agent.react(
|
act = await temp_agent.react(
|
||||||
question = request.question,
|
question = request.question,
|
||||||
history = history,
|
history = history,
|
||||||
think = think,
|
think = think,
|
||||||
observe = observe,
|
observe = observe,
|
||||||
context = flow,
|
context = UserAwareContext(flow, request.user),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Action: {act}")
|
logger.debug(f"Action: {act}")
|
||||||
|
|
|
||||||
|
|
@ -87,9 +87,10 @@ class McpToolImpl:
|
||||||
|
|
||||||
# This tool implementation knows how to query structured data using natural language
|
# This tool implementation knows how to query structured data using natural language
|
||||||
class StructuredQueryImpl:
|
class StructuredQueryImpl:
|
||||||
def __init__(self, context, collection=None):
|
def __init__(self, context, collection=None, user=None):
|
||||||
self.context = context
|
self.context = context
|
||||||
self.collection = collection # For multi-tenant scenarios
|
self.collection = collection # For multi-tenant scenarios
|
||||||
|
self.user = user # User context for multi-tenancy
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_arguments():
|
def get_arguments():
|
||||||
|
|
@ -105,8 +106,13 @@ class StructuredQueryImpl:
|
||||||
client = self.context("structured-query-request")
|
client = self.context("structured-query-request")
|
||||||
logger.debug("Structured query question...")
|
logger.debug("Structured query question...")
|
||||||
|
|
||||||
|
# Get user from client context if available, otherwise use instance user or default
|
||||||
|
user = getattr(client, '_current_user', self.user or "trustgraph")
|
||||||
|
|
||||||
result = await client.structured_query(
|
result = await client.structured_query(
|
||||||
arguments.get("question")
|
question=arguments.get("question"),
|
||||||
|
user=user,
|
||||||
|
collection=self.collection or "default"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format the result for the agent
|
# Format the result for the agent
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ def filter_tools_by_group_and_state(
|
||||||
# Apply defaults as specified in tech spec
|
# Apply defaults as specified in tech spec
|
||||||
if requested_groups is None:
|
if requested_groups is None:
|
||||||
requested_groups = ["default"]
|
requested_groups = ["default"]
|
||||||
if current_state is None:
|
if current_state is None or current_state == "":
|
||||||
current_state = "undefined"
|
current_state = "undefined"
|
||||||
|
|
||||||
logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")
|
logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")
|
||||||
|
|
|
||||||
|
|
@ -111,11 +111,10 @@ class Processor(FlowProcessor):
|
||||||
else:
|
else:
|
||||||
variables_as_strings[key] = str(value)
|
variables_as_strings[key] = str(value)
|
||||||
|
|
||||||
# Use standard TrustGraph user/collection values
|
# Use user/collection values from request
|
||||||
# These should eventually come from authentication/context
|
|
||||||
objects_request = ObjectsQueryRequest(
|
objects_request = ObjectsQueryRequest(
|
||||||
user="trustgraph", # Standard TrustGraph user
|
user=request.user,
|
||||||
collection="default", # Standard default collection
|
collection=request.collection,
|
||||||
query=nlp_response.graphql_query,
|
query=nlp_response.graphql_query,
|
||||||
variables=variables_as_strings,
|
variables=variables_as_strings,
|
||||||
operation_name=None
|
operation_name=None
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue