mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-10 15:52:36 +02:00
Feature/tool group (#484)
* Tech spec for tool group * Partial tool group implementation * Tool group tests
This commit is contained in:
parent
672e358b2f
commit
e74eb5d1ff
9 changed files with 1304 additions and 6 deletions
|
|
@ -18,6 +18,7 @@ from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
|||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl
|
||||
from . agent_manager import AgentManager
|
||||
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
|
||||
|
||||
from . types import Final, Action, Tool, Argument
|
||||
|
||||
|
|
@ -142,6 +143,9 @@ class Processor(AgentService):
|
|||
f"Tool type {impl_id} not known"
|
||||
)
|
||||
|
||||
# Validate tool configuration
|
||||
validate_tool_config(data)
|
||||
|
||||
tools[name] = Tool(
|
||||
name=name,
|
||||
description=data.get("description"),
|
||||
|
|
@ -219,9 +223,24 @@ class Processor(AgentService):
|
|||
|
||||
await respond(r)
|
||||
|
||||
# Apply tool filtering based on request groups and state
|
||||
filtered_tools = filter_tools_by_group_and_state(
|
||||
tools=self.agent.tools,
|
||||
requested_groups=getattr(request, 'group', None),
|
||||
current_state=getattr(request, 'state', None)
|
||||
)
|
||||
|
||||
logger.info(f"Filtered from {len(self.agent.tools)} to {len(filtered_tools)} available tools")
|
||||
|
||||
# Create temporary agent with filtered tools
|
||||
temp_agent = AgentManager(
|
||||
tools=filtered_tools,
|
||||
additional_context=self.agent.additional_context
|
||||
)
|
||||
|
||||
logger.debug("Call React")
|
||||
|
||||
act = await self.agent.react(
|
||||
act = await temp_agent.react(
|
||||
question = request.question,
|
||||
history = history,
|
||||
think = think,
|
||||
|
|
@ -255,11 +274,17 @@ class Processor(AgentService):
|
|||
logger.debug("Send next...")
|
||||
|
||||
history.append(act)
|
||||
|
||||
# Handle state transitions if tool execution was successful
|
||||
next_state = request.state
|
||||
if act.name in filtered_tools:
|
||||
executed_tool = filtered_tools[act.name]
|
||||
next_state = get_next_state(executed_tool, request.state or "undefined")
|
||||
|
||||
r = AgentRequest(
|
||||
question=request.question,
|
||||
plan=request.plan,
|
||||
state=request.state,
|
||||
state=next_state,
|
||||
group=getattr(request, 'group', []),
|
||||
history=[
|
||||
AgentStep(
|
||||
thought=h.thought,
|
||||
|
|
|
|||
165
trustgraph-flow/trustgraph/agent/tool_filter.py
Normal file
165
trustgraph-flow/trustgraph/agent/tool_filter.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""
|
||||
Tool filtering logic for the TrustGraph tool group system.
|
||||
|
||||
Provides functions to filter available tools based on group membership
|
||||
and execution state as defined in the tool-group tech spec.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_tools_by_group_and_state(
|
||||
tools: Dict[str, Any],
|
||||
requested_groups: Optional[List[str]] = None,
|
||||
current_state: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter tools based on group membership and execution state.
|
||||
|
||||
Args:
|
||||
tools: Dictionary of tool_name -> tool_object
|
||||
requested_groups: List of groups requested (defaults to ["default"])
|
||||
current_state: Current execution state (defaults to "undefined")
|
||||
|
||||
Returns:
|
||||
Dictionary of filtered tools that match group and state criteria
|
||||
"""
|
||||
|
||||
# Apply defaults as specified in tech spec
|
||||
if requested_groups is None:
|
||||
requested_groups = ["default"]
|
||||
if current_state is None:
|
||||
current_state = "undefined"
|
||||
|
||||
logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}")
|
||||
|
||||
filtered_tools = {}
|
||||
|
||||
for tool_name, tool in tools.items():
|
||||
if _is_tool_available(tool, requested_groups, current_state):
|
||||
filtered_tools[tool_name] = tool
|
||||
else:
|
||||
logger.debug(f"Tool {tool_name} filtered out")
|
||||
|
||||
logger.info(f"Filtered {len(tools)} tools to {len(filtered_tools)} available tools")
|
||||
return filtered_tools
|
||||
|
||||
|
||||
def _is_tool_available(
|
||||
tool: Any,
|
||||
requested_groups: List[str],
|
||||
current_state: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a tool is available based on group and state criteria.
|
||||
|
||||
Args:
|
||||
tool: Tool object with config attribute containing group/state metadata
|
||||
requested_groups: List of requested groups
|
||||
current_state: Current execution state
|
||||
|
||||
Returns:
|
||||
True if tool should be available, False otherwise
|
||||
"""
|
||||
|
||||
# Extract tool configuration
|
||||
config = getattr(tool, 'config', {})
|
||||
|
||||
# Get tool groups (default to ["default"] if not specified)
|
||||
tool_groups = config.get('group', ["default"])
|
||||
if not isinstance(tool_groups, list):
|
||||
tool_groups = [tool_groups]
|
||||
|
||||
# Get tool applicable states (default to all states if not specified)
|
||||
applicable_states = config.get('applicable-states', ["*"])
|
||||
if not isinstance(applicable_states, list):
|
||||
applicable_states = [applicable_states]
|
||||
|
||||
# Apply group filtering logic from tech spec:
|
||||
# Tool is available if intersection(tool_groups, requested_groups) is not empty
|
||||
# OR "*" is in requested_groups (wildcard access)
|
||||
group_match = (
|
||||
"*" in requested_groups or
|
||||
bool(set(tool_groups) & set(requested_groups))
|
||||
)
|
||||
|
||||
# Apply state filtering logic from tech spec:
|
||||
# Tool is available if current_state is in applicable_states
|
||||
# OR "*" is in applicable_states (available in all states)
|
||||
state_match = (
|
||||
"*" in applicable_states or
|
||||
current_state in applicable_states
|
||||
)
|
||||
|
||||
is_available = group_match and state_match
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
f"Tool availability check: tool_groups={tool_groups}, "
|
||||
f"requested_groups={requested_groups}, applicable_states={applicable_states}, "
|
||||
f"current_state={current_state}, group_match={group_match}, "
|
||||
f"state_match={state_match}, is_available={is_available}"
|
||||
)
|
||||
|
||||
return is_available
|
||||
|
||||
|
||||
def get_next_state(tool: Any, current_state: str) -> str:
|
||||
"""
|
||||
Get the next state after successful tool execution.
|
||||
|
||||
Args:
|
||||
tool: Tool object with config attribute
|
||||
current_state: Current execution state
|
||||
|
||||
Returns:
|
||||
Next state, or current_state if no transition is defined
|
||||
"""
|
||||
config = getattr(tool, 'config', {})
|
||||
if config is None:
|
||||
config = {}
|
||||
next_state = config.get('state')
|
||||
|
||||
if next_state:
|
||||
logger.debug(f"State transition: {current_state} -> {next_state}")
|
||||
return next_state
|
||||
else:
|
||||
logger.debug(f"No state transition defined, staying in {current_state}")
|
||||
return current_state
|
||||
|
||||
|
||||
def validate_tool_config(config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate tool configuration for group and state fields.
|
||||
|
||||
Args:
|
||||
config: Tool configuration dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
|
||||
# Validate group field
|
||||
if 'group' in config:
|
||||
groups = config['group']
|
||||
if not isinstance(groups, list):
|
||||
raise ValueError("Tool 'group' field must be a list of strings")
|
||||
if not all(isinstance(g, str) for g in groups):
|
||||
raise ValueError("All group names must be strings")
|
||||
|
||||
# Validate state field
|
||||
if 'state' in config:
|
||||
state = config['state']
|
||||
if not isinstance(state, str):
|
||||
raise ValueError("Tool 'state' field must be a string")
|
||||
|
||||
# Validate applicable-states field
|
||||
if 'applicable-states' in config:
|
||||
states = config['applicable-states']
|
||||
if not isinstance(states, list):
|
||||
raise ValueError("Tool 'applicable-states' field must be a list of strings")
|
||||
if not all(isinstance(s, str) for s in states):
|
||||
raise ValueError("All state names must be strings")
|
||||
Loading…
Add table
Add a link
Reference in a new issue