Mirror of https://github.com/trustgraph-ai/trustgraph.git (synced 2026-06-09 06:45:13 +02:00)
Extend use of user + collection fields (#503)
* Collection+user fields in structured query
* User/collection in structured query & agent
This commit is contained in:
parent a92050c411
commit f22bf13aa6
15 changed files with 122 additions and 45 deletions
|
|
@ -59,8 +59,10 @@ class TestAgentStructuredQueryIntegration:
|
|||
# Create agent request
|
||||
request = AgentRequest(
|
||||
question="I need to find all customers from New York. Use the structured query tool to get this information.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -118,7 +120,8 @@ Args: {
|
|||
# Verify structured query was called
|
||||
mock_structured_client.structured_query.assert_called_once()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args[0][0] # positional argument
|
||||
# Check keyword arguments
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "customers" in question_arg.lower()
|
||||
assert "new york" in question_arg.lower()
|
||||
|
||||
|
|
@ -140,8 +143,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="Find data from a table that doesn't exist using structured query.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -198,8 +203,9 @@ Args: {
|
|||
# Agent should handle the error gracefully
|
||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
# The tool should have returned an error response that contains error info
|
||||
structured_query_call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "table" in structured_query_call_args[0].lower() or "exist" in structured_query_call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "table" in question_arg.lower() or "exist" in question_arg.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
||||
|
|
@ -209,8 +215,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="First find all customers from California, then tell me how many orders they have made.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -272,8 +280,9 @@ Args: {
|
|||
|
||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
# Verify the structured query was called with customer-related question
|
||||
call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "california" in call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "california" in question_arg.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_structured_query_with_collection_parameter(self, agent_processor):
|
||||
|
|
@ -295,8 +304,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="Query the sales data for recent transactions.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -359,8 +370,9 @@ Args: {
|
|||
|
||||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
# Check the query was about sales/transactions
|
||||
call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "sales" in call_args[0].lower() or "transactions" in call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "sales" in question_arg.lower() or "transactions" in question_arg.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
||||
|
|
@ -390,8 +402,10 @@ Args: {
|
|||
|
||||
request = AgentRequest(
|
||||
question="Get customer information and format it nicely.",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -463,5 +477,6 @@ Args: {
|
|||
assert any(isinstance(resp, AgentResponse) for resp in responses)
|
||||
|
||||
# Check that the query was about customer information
|
||||
call_args = mock_structured_client.structured_query.call_args[0]
|
||||
assert "customer" in call_args[0].lower()
|
||||
call_args = mock_structured_client.structured_query.call_args
|
||||
question_arg = call_args.kwargs.get("question") or call_args[1].get("question")
|
||||
assert "customer" in question_arg.lower()
|
||||
|
|
@ -41,7 +41,9 @@ class TestStructuredQueryServiceIntegration:
|
|||
"""Test complete structured query processing pipeline"""
|
||||
# Arrange - Create realistic query request
|
||||
request = StructuredQueryRequest(
|
||||
question="Show me all customers from California who have made purchases over $500"
|
||||
question="Show me all customers from California who have made purchases over $500",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
|
|||
|
|
@ -44,7 +44,9 @@ class TestStructuredQueryProcessor:
|
|||
"""Test successful end-to-end query processing"""
|
||||
# Arrange
|
||||
request = StructuredQueryRequest(
|
||||
question="Show me all customers from New York"
|
||||
question="Show me all customers from New York",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
|
|||
|
|
@ -132,11 +132,15 @@ class FlowInstance:
|
|||
input
|
||||
)["response"]
|
||||
|
||||
def agent(self, question):
|
||||
def agent(self, question, user="trustgraph", state="", group=None, history=None):
|
||||
|
||||
# The input consists of a question
|
||||
# The input consists of a question and optional context
|
||||
input = {
|
||||
"question": question
|
||||
"question": question,
|
||||
"user": user,
|
||||
"state": state,
|
||||
"group": group or [],
|
||||
"history": history or []
|
||||
}
|
||||
|
||||
return self.request(
|
||||
|
|
@ -456,20 +460,24 @@ class FlowInstance:
|
|||
|
||||
return response
|
||||
|
||||
def structured_query(self, question):
|
||||
def structured_query(self, question, user="trustgraph", collection="default"):
|
||||
"""
|
||||
Execute a natural language question against structured data.
|
||||
Combines NLP query conversion and GraphQL execution.
|
||||
|
||||
Args:
|
||||
question: Natural language question
|
||||
user: Cassandra keyspace identifier (default: "trustgraph")
|
||||
collection: Data collection identifier (default: "default")
|
||||
|
||||
Returns:
|
||||
dict with data and optional errors
|
||||
"""
|
||||
|
||||
input = {
|
||||
"question": question
|
||||
"question": question,
|
||||
"user": user,
|
||||
"collection": collection
|
||||
}
|
||||
|
||||
response = self.request(
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
|
|||
from .. schema import StructuredQueryRequest, StructuredQueryResponse
|
||||
|
||||
class StructuredQueryClient(RequestResponse):
|
||||
async def structured_query(self, question, timeout=600):
|
||||
async def structured_query(self, question, user="trustgraph", collection="default", timeout=600):
|
||||
resp = await self.request(
|
||||
StructuredQueryRequest(
|
||||
question = question
|
||||
question = question,
|
||||
user = user,
|
||||
collection = collection
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,17 +9,19 @@ class AgentRequestTranslator(MessageTranslator):
|
|||
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
|
||||
return AgentRequest(
|
||||
question=data["question"],
|
||||
plan=data.get("plan", ""),
|
||||
state=data.get("state", ""),
|
||||
history=data.get("history", [])
|
||||
group=data.get("group", []),
|
||||
history=data.get("history", []),
|
||||
user=data.get("user", "trustgraph")
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"question": obj.question,
|
||||
"plan": obj.plan,
|
||||
"state": obj.state,
|
||||
"history": obj.history
|
||||
"group": obj.group,
|
||||
"history": obj.history,
|
||||
"user": obj.user
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,12 +9,16 @@ class StructuredQueryRequestTranslator(MessageTranslator):
|
|||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest:
|
||||
return StructuredQueryRequest(
|
||||
question=data.get("question", "")
|
||||
question=data.get("question", ""),
|
||||
user=data.get("user", "trustgraph"), # Default fallback
|
||||
collection=data.get("collection", "default") # Default fallback
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"question": obj.question
|
||||
"question": obj.question,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,12 +13,14 @@ class AgentStep(Record):
|
|||
action = String()
|
||||
arguments = Map(String())
|
||||
observation = String()
|
||||
user = String() # User context for the step
|
||||
|
||||
class AgentRequest(Record):
|
||||
question = String()
|
||||
state = String()
|
||||
group = Array(String())
|
||||
history = Array(AgentStep())
|
||||
user = String() # User context for multi-tenancy
|
||||
|
||||
class AgentResponse(Record):
|
||||
answer = String()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from ..core.topic import topic
|
|||
|
||||
class StructuredQueryRequest(Record):
|
||||
question = String()
|
||||
user = String() # Cassandra keyspace identifier
|
||||
collection = String() # Data collection identifier
|
||||
|
||||
class StructuredQueryResponse(Record):
|
||||
error = Error()
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ async def question(
|
|||
"flow": flow_id,
|
||||
"request": {
|
||||
"question": question,
|
||||
"user": user,
|
||||
"state": state or "",
|
||||
"group": [],
|
||||
"history": []
|
||||
}
|
||||
|
||||
})
|
||||
|
|
|
|||
|
|
@ -79,11 +79,11 @@ def format_table_data(rows, table_name, output_format):
|
|||
else:
|
||||
return json.dumps({table_name: rows}, indent=2)
|
||||
|
||||
def structured_query(url, flow_id, question, output_format='table'):
|
||||
def structured_query(url, flow_id, question, user='trustgraph', collection='default', output_format='table'):
|
||||
|
||||
api = Api(url).flow().id(flow_id)
|
||||
|
||||
resp = api.structured_query(question=question)
|
||||
resp = api.structured_query(question=question, user=user, collection=collection)
|
||||
|
||||
# Check for errors
|
||||
if "error" in resp and resp["error"]:
|
||||
|
|
@ -132,6 +132,18 @@ def main():
|
|||
help='Natural language question to execute',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--user',
|
||||
default='trustgraph',
|
||||
help='Cassandra keyspace identifier (default: trustgraph)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--collection',
|
||||
default='default',
|
||||
help='Data collection identifier (default: default)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--format',
|
||||
choices=['table', 'json', 'csv'],
|
||||
|
|
@ -147,6 +159,8 @@ def main():
|
|||
url=args.url,
|
||||
flow_id=args.flow_id,
|
||||
question=args.question,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
output_format=args.format,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -148,7 +148,8 @@ class Processor(AgentService):
|
|||
elif impl_id == "structured-query":
|
||||
impl = functools.partial(
|
||||
StructuredQueryImpl,
|
||||
collection=data.get("collection")
|
||||
collection=data.get("collection"),
|
||||
user=None # User will be provided dynamically via context
|
||||
)
|
||||
arguments = StructuredQueryImpl.get_arguments()
|
||||
else:
|
||||
|
|
@ -253,12 +254,26 @@ class Processor(AgentService):
|
|||
|
||||
logger.debug("Call React")
|
||||
|
||||
# Create user-aware context wrapper that preserves the flow interface
|
||||
# but adds user information for tools that need it
|
||||
class UserAwareContext:
|
||||
def __init__(self, flow, user):
|
||||
self._flow = flow
|
||||
self._user = user
|
||||
|
||||
def __call__(self, service_name):
|
||||
client = self._flow(service_name)
|
||||
# For structured query clients, store user context
|
||||
if service_name == "structured-query-request":
|
||||
client._current_user = self._user
|
||||
return client
|
||||
|
||||
act = await temp_agent.react(
|
||||
question = request.question,
|
||||
history = history,
|
||||
think = think,
|
||||
observe = observe,
|
||||
context = flow,
|
||||
context = UserAwareContext(flow, request.user),
|
||||
)
|
||||
|
||||
logger.debug(f"Action: {act}")
|
||||
|
|
|
|||
|
|
@ -87,9 +87,10 @@ class McpToolImpl:
|
|||
|
||||
# This tool implementation knows how to query structured data using natural language
|
||||
class StructuredQueryImpl:
|
||||
def __init__(self, context, collection=None):
|
||||
def __init__(self, context, collection=None, user=None):
|
||||
self.context = context
|
||||
self.collection = collection # For multi-tenant scenarios
|
||||
self.user = user # User context for multi-tenancy
|
||||
|
||||
@staticmethod
|
||||
def get_arguments():
|
||||
|
|
@ -105,8 +106,13 @@ class StructuredQueryImpl:
|
|||
client = self.context("structured-query-request")
|
||||
logger.debug("Structured query question...")
|
||||
|
||||
# Get user from client context if available, otherwise use instance user or default
|
||||
user = getattr(client, '_current_user', self.user or "trustgraph")
|
||||
|
||||
result = await client.structured_query(
|
||||
arguments.get("question")
|
||||
question=arguments.get("question"),
|
||||
user=user,
|
||||
collection=self.collection or "default"
|
||||
)
|
||||
|
||||
# Format the result for the agent
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def filter_tools_by_group_and_state(
|
|||
# Apply defaults as specified in tech spec
|
||||
if requested_groups is None:
|
||||
requested_groups = ["default"]
|
||||
if current_state is None:
|
||||
if current_state is None or current_state == "":
|
||||
current_state = "undefined"
|
||||
|
||||
logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")
|
||||
|
|
|
|||
|
|
@ -111,11 +111,10 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
variables_as_strings[key] = str(value)
|
||||
|
||||
# Use standard TrustGraph user/collection values
|
||||
# These should eventually come from authentication/context
|
||||
# Use user/collection values from request
|
||||
objects_request = ObjectsQueryRequest(
|
||||
user="trustgraph", # Standard TrustGraph user
|
||||
collection="default", # Standard default collection
|
||||
user=request.user,
|
||||
collection=request.collection,
|
||||
query=nlp_response.graphql_query,
|
||||
variables=variables_as_strings,
|
||||
operation_name=None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue